diff --git a/.ci/scripts/auditwheel_wrapper.py b/.ci/scripts/auditwheel_wrapper.py index 9599b79e50..9832821221 100755 --- a/.ci/scripts/auditwheel_wrapper.py +++ b/.ci/scripts/auditwheel_wrapper.py @@ -25,7 +25,6 @@ import argparse import os import subprocess -from typing import Optional from zipfile import ZipFile from packaging.tags import Tag @@ -80,7 +79,7 @@ def cpython(wheel_file: str, name: str, version: Version, tag: Tag) -> str: return new_wheel_file -def main(wheel_file: str, dest_dir: str, archs: Optional[str]) -> None: +def main(wheel_file: str, dest_dir: str, archs: str | None) -> None: """Entry point""" # Parse the wheel file name into its parts. Note that `parse_wheel_filename` diff --git a/changelog.d/19111.misc b/changelog.d/19111.misc new file mode 100644 index 0000000000..cb4ca85c47 --- /dev/null +++ b/changelog.d/19111.misc @@ -0,0 +1 @@ +Write union types as `X | Y` where possible, as per PEP 604, added in Python 3.10. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 9b5d33d2b1..1c867d5336 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -33,7 +33,6 @@ import sys import time import urllib from http import TwistedHttpClient -from typing import Optional import urlparse from signedjson.key import NACL_ED25519, decode_verify_key_bytes @@ -726,7 +725,7 @@ class SynapseCmd(cmd.Cmd): method, path, data=None, - query_params: Optional[dict] = None, + query_params: dict | None = None, alt_text=None, ): """Runs an HTTP request and pretty prints the output. diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 54363e4259..b92ccdd932 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -22,7 +22,6 @@ import json import urllib from pprint import pformat -from typing import Optional from twisted.internet import defer, reactor from twisted.web.client import Agent, readBody @@ -90,7 +89,7 @@ class TwistedHttpClient(HttpClient): body = yield readBody(response) return json.loads(body) - def _create_put_request(self, url, json_data, headers_dict: Optional[dict] = None): + def _create_put_request(self, url, json_data, headers_dict: dict | None = None): """Wrapper of _create_request to issue a PUT request""" headers_dict = headers_dict or {} @@ -101,7 +100,7 @@ class TwistedHttpClient(HttpClient): "PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict ) - def _create_get_request(self, url, headers_dict: Optional[dict] = None): + def _create_get_request(self, url, headers_dict: dict | None = None): """Wrapper of _create_request to issue a GET request""" return self._create_request("GET", url, headers_dict=headers_dict or {}) @@ -113,7 +112,7 @@ class TwistedHttpClient(HttpClient): data=None, qparams=None, jsonreq=True, - headers: Optional[dict] = None, + headers: dict | None = None, ): headers = headers or {} @@ -138,7 +137,7 @@ class TwistedHttpClient(HttpClient): @defer.inlineCallbacks def _create_request( - self, method, url, producer=None, headers_dict: Optional[dict] = None + self, method, url, producer=None, headers_dict: dict | None = None ): """Creates and sends a request to the given url""" headers_dict = headers_dict or {} diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 2451d1f300..e19b0a0039 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -68,7 +68,6 @@ from typing import ( Mapping, MutableMapping, NoReturn, - Optional, SupportsIndex, ) @@ -468,7 +467,7 @@ def add_worker_roles_to_shared_config( def merge_worker_template_configs( - existing_dict: Optional[dict[str, Any]], + existing_dict: dict[str, Any] | None, to_be_merged_dict: dict[str, Any], ) -> dict[str, Any]: """When given an existing dict of worker template configuration consisting with both @@ -1026,7 +1025,7 @@ def generate_worker_log_config( Returns: the path to the generated file """ # Check whether we should write worker logs to disk, in addition to the console - extra_log_template_args: dict[str, Optional[str]] = {} + extra_log_template_args: dict[str, str | None] = {} if environ.get("SYNAPSE_WORKERS_WRITE_LOGS_TO_DISK"): extra_log_template_args["LOG_FILE_PATH"] = f"{data_dir}/logs/{worker_name}.log" diff --git a/docker/start.py b/docker/start.py index daa041d463..c88d23695f 100755 --- a/docker/start.py +++ b/docker/start.py @@ -6,7 +6,7 @@ import os import platform import subprocess import sys -from typing import Any, Mapping, MutableMapping, NoReturn, Optional +from typing import Any, Mapping, MutableMapping, NoReturn import jinja2 @@ -50,7 +50,7 @@ def generate_config_from_template( config_dir: str, config_path: str, os_environ: Mapping[str, str], - ownership: Optional[str], + ownership: str | None, ) -> None: """Generate a homeserver.yaml from environment variables @@ -147,7 +147,7 @@ def generate_config_from_template( subprocess.run(args, check=True) -def run_generate_config(environ: Mapping[str, str], ownership: Optional[str]) -> None: +def run_generate_config(environ: Mapping[str, str], ownership: str | None) -> None: """Run synapse with a --generate-config param to generate a template config file Args: diff --git a/docs/development/synapse_architecture/cancellation.md b/docs/development/synapse_architecture/cancellation.md index ef9e022635..a12f119fb5 100644 --- a/docs/development/synapse_architecture/cancellation.md +++ b/docs/development/synapse_architecture/cancellation.md @@ -299,7 +299,7 @@ logcontext is not finished before the `async` processing completes. **Bad**: ```python -cache: Optional[ObservableDeferred[None]] = None +cache: ObservableDeferred[None] | None = None async def do_something_else( to_resolve: Deferred[None] @@ -326,7 +326,7 @@ with LoggingContext("request-1"): **Good**: ```python -cache: Optional[ObservableDeferred[None]] = None +cache: ObservableDeferred[None] | None = None async def do_something_else( to_resolve: Deferred[None] @@ -358,7 +358,7 @@ with LoggingContext("request-1"): **OK**: ```python -cache: Optional[ObservableDeferred[None]] = None +cache: ObservableDeferred[None] | None = None async def do_something_else( to_resolve: Deferred[None] diff --git a/docs/modules/account_data_callbacks.md b/docs/modules/account_data_callbacks.md index 25de911627..02b8c18bbf 100644 --- a/docs/modules/account_data_callbacks.md +++ b/docs/modules/account_data_callbacks.md @@ -15,7 +15,7 @@ _First introduced in Synapse v1.57.0_ ```python async def on_account_data_updated( user_id: str, - room_id: Optional[str], + room_id: str | None, account_data_type: str, content: "synapse.module_api.JsonDict", ) -> None: @@ -82,7 +82,7 @@ class CustomAccountDataModule: async def log_new_account_data( self, user_id: str, - room_id: Optional[str], + room_id: str | None, account_data_type: str, content: JsonDict, ) -> None: diff --git a/docs/modules/account_validity_callbacks.md b/docs/modules/account_validity_callbacks.md index f5eefcd7d6..2deb43c1be 100644 --- a/docs/modules/account_validity_callbacks.md +++ b/docs/modules/account_validity_callbacks.md @@ -12,7 +12,7 @@ The available account validity callbacks are: _First introduced in Synapse v1.39.0_ ```python -async def is_user_expired(user: str) -> Optional[bool] +async def is_user_expired(user: str) -> bool | None ``` Called when processing any authenticated request (except for logout requests). The module diff --git a/docs/modules/media_repository_callbacks.md b/docs/modules/media_repository_callbacks.md index 7c724038a7..d7c9074bde 100644 --- a/docs/modules/media_repository_callbacks.md +++ b/docs/modules/media_repository_callbacks.md @@ -11,7 +11,7 @@ The available media repository callbacks are: _First introduced in Synapse v1.132.0_ ```python -async def get_media_config_for_user(user_id: str) -> Optional[JsonDict] +async def get_media_config_for_user(user_id: str) -> JsonDict | None ``` ** @@ -70,7 +70,7 @@ implementations of this callback. _First introduced in Synapse v1.139.0_ ```python -async def get_media_upload_limits_for_user(user_id: str, size: int) -> Optional[List[synapse.module_api.MediaUploadLimit]] +async def get_media_upload_limits_for_user(user_id: str, size: int) -> list[synapse.module_api.MediaUploadLimit] | None ``` ** diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index d66ac7df31..88b22fdf21 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -23,12 +23,7 @@ async def check_auth( user: str, login_type: str, login_dict: "synapse.module_api.JsonDict", -) -> Optional[ - Tuple[ - str, - Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] - ] -] +) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None ``` The login type and field names should be provided by the user in the @@ -67,12 +62,7 @@ async def check_3pid_auth( medium: str, address: str, password: str, -) -> Optional[ - Tuple[ - str, - Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] - ] -] +) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] ``` Called when a user attempts to register or log in with a third party identifier, @@ -98,7 +88,7 @@ _First introduced in Synapse v1.46.0_ ```python async def on_logged_out( user_id: str, - device_id: Optional[str], + device_id: str | None, access_token: str ) -> None ``` @@ -119,7 +109,7 @@ _First introduced in Synapse v1.52.0_ async def get_username_for_registration( uia_results: Dict[str, Any], params: Dict[str, Any], -) -> Optional[str] +) -> str | None ``` Called when registering a new user. The module can return a username to set for the user @@ -180,7 +170,7 @@ _First introduced in Synapse v1.54.0_ async def get_displayname_for_registration( uia_results: Dict[str, Any], params: Dict[str, Any], -) -> Optional[str] +) -> str | None ``` Called when registering a new user. The module can return a display name to set for the @@ -259,12 +249,7 @@ class MyAuthProvider: username: str, login_type: str, login_dict: "synapse.module_api.JsonDict", - ) -> Optional[ - Tuple[ - str, - Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], - ] - ]: + ) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None: if login_type != "my.login_type": return None @@ -276,12 +261,7 @@ class MyAuthProvider: username: str, login_type: str, login_dict: "synapse.module_api.JsonDict", - ) -> Optional[ - Tuple[ - str, - Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], - ] - ]: + ) -> tuple[str, Callable[["synapse.module_api.LoginResponse"], Awaitable[None]] | None] | None: if login_type != "m.login.password": return None diff --git a/docs/modules/presence_router_callbacks.md b/docs/modules/presence_router_callbacks.md index b210f0e3cd..f865e79f53 100644 --- a/docs/modules/presence_router_callbacks.md +++ b/docs/modules/presence_router_callbacks.md @@ -23,7 +23,7 @@ _First introduced in Synapse v1.42.0_ ```python async def get_users_for_states( state_updates: Iterable["synapse.api.UserPresenceState"], -) -> Dict[str, Set["synapse.api.UserPresenceState"]] +) -> dict[str, set["synapse.api.UserPresenceState"]] ``` **Requires** `get_interested_users` to also be registered @@ -45,7 +45,7 @@ _First introduced in Synapse v1.42.0_ ```python async def get_interested_users( user_id: str -) -> Union[Set[str], "synapse.module_api.PRESENCE_ALL_USERS"] +) -> set[str] | "synapse.module_api.PRESENCE_ALL_USERS" ``` **Requires** `get_users_for_states` to also be registered @@ -73,7 +73,7 @@ that `@alice:example.org` receives all presence updates from `@bob:example.com` `@charlie:somewhere.org`, regardless of whether Alice shares a room with any of them. ```python -from typing import Dict, Iterable, Set, Union +from typing import Iterable from synapse.module_api import ModuleApi @@ -90,7 +90,7 @@ class CustomPresenceRouter: async def get_users_for_states( self, state_updates: Iterable["synapse.api.UserPresenceState"], - ) -> Dict[str, Set["synapse.api.UserPresenceState"]]: + ) -> dict[str, set["synapse.api.UserPresenceState"]]: res = {} for update in state_updates: if ( @@ -104,7 +104,7 @@ class CustomPresenceRouter: async def get_interested_users( self, user_id: str, - ) -> Union[Set[str], "synapse.module_api.PRESENCE_ALL_USERS"]: + ) -> set[str] | "synapse.module_api.PRESENCE_ALL_USERS": if user_id == "@alice:example.com": return {"@bob:example.com", "@charlie:somewhere.org"} diff --git a/docs/modules/ratelimit_callbacks.md b/docs/modules/ratelimit_callbacks.md index 30d94024fa..048bdc6f9e 100644 --- a/docs/modules/ratelimit_callbacks.md +++ b/docs/modules/ratelimit_callbacks.md @@ -11,7 +11,7 @@ The available ratelimit callbacks are: _First introduced in Synapse v1.132.0_ ```python -async def get_ratelimit_override_for_user(user: str, limiter_name: str) -> Optional[synapse.module_api.RatelimitOverride] +async def get_ratelimit_override_for_user(user: str, limiter_name: str) -> synapse.module_api.RatelimitOverride | None ``` ** diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 0f15a9dcc5..0d261e844f 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -331,9 +331,9 @@ search results; otherwise return `False`. The profile is represented as a dictionary with the following keys: * `user_id: str`. The Matrix ID for this user. -* `display_name: Optional[str]`. The user's display name, or `None` if this user +* `display_name: str | None`. The user's display name, or `None` if this user has not set a display name. -* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None` +* `avatar_url: str | None`. The `mxc://` URL to the user's avatar, or `None` if this user has not set an avatar. The module is given a copy of the original dictionary, so modifying it from within the @@ -352,10 +352,10 @@ _First introduced in Synapse v1.37.0_ ```python async def check_registration_for_spam( - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[Tuple[str, str]], - auth_provider_id: Optional[str] = None, + auth_provider_id: str | None = None, ) -> "synapse.spam_checker_api.RegistrationBehaviour" ``` @@ -438,10 +438,10 @@ _First introduced in Synapse v1.87.0_ ```python async def check_login_for_spam( user_id: str, - device_id: Optional[str], - initial_display_name: Optional[str], - request_info: Collection[Tuple[Optional[str], str]], - auth_provider_id: Optional[str] = None, + device_id: str | None, + initial_display_name: str | None, + request_info: Collection[tuple[str | None, str]], + auth_provider_id: str | None = None, ) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"] ``` @@ -509,7 +509,7 @@ class ListSpamChecker: resource=IsUserEvilResource(config), ) - async def check_event_for_spam(self, event: "synapse.events.EventBase") -> Union[Literal["NOT_SPAM"], Codes]: + async def check_event_for_spam(self, event: "synapse.events.EventBase") -> Literal["NOT_SPAM"] | Codes: if event.sender in self.evil_users: return Codes.FORBIDDEN else: diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index b97e28db11..1474b2dfd5 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -16,7 +16,7 @@ _First introduced in Synapse v1.39.0_ async def check_event_allowed( event: "synapse.events.EventBase", state_events: "synapse.types.StateMap", -) -> Tuple[bool, Optional[dict]] +) -> tuple[bool, dict | None] ``` ** @@ -340,7 +340,7 @@ class EventCensorer: self, event: "synapse.events.EventBase", state_events: "synapse.types.StateMap", - ) -> Tuple[bool, Optional[dict]]: + ) -> Tuple[bool, dict | None]: event_dict = event.get_dict() new_event_content = await self.api.http_client.post_json_get_json( uri=self._endpoint, post_json=event_dict, diff --git a/docs/presence_router_module.md b/docs/presence_router_module.md index face54fe2b..092b566c5f 100644 --- a/docs/presence_router_module.md +++ b/docs/presence_router_module.md @@ -76,7 +76,7 @@ possible. #### `get_interested_users` ```python -async def get_interested_users(self, user_id: str) -> Union[Set[str], str] +async def get_interested_users(self, user_id: str) -> set[str] | str ``` **Required.** An asynchronous method that is passed a single Matrix User ID. This @@ -182,7 +182,7 @@ class ExamplePresenceRouter: async def get_interested_users( self, user_id: str, - ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + ) -> set[str] | PresenceRouter.ALL_USERS: """ Retrieve a list of users that `user_id` is interested in receiving the presence of. This will be in addition to those they share a room with. diff --git a/pyproject.toml b/pyproject.toml index 991cb3e7f3..0eef197cf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,10 +80,15 @@ select = [ "G", # pyupgrade "UP006", + "UP007", + "UP045", ] extend-safe-fixes = [ - # pyupgrade - "UP006" + # pyupgrade rules compatible with Python >= 3.9 + "UP006", + "UP007", + # pyupgrade rules compatible with Python >= 3.10 + "UP045", ] [tool.ruff.lint.isort] diff --git a/scripts-dev/build_debian_packages.py b/scripts-dev/build_debian_packages.py index 60aa8a5796..d462fe6c56 100755 --- a/scripts-dev/build_debian_packages.py +++ b/scripts-dev/build_debian_packages.py @@ -18,7 +18,7 @@ import sys import threading from concurrent.futures import ThreadPoolExecutor from types import FrameType -from typing import Collection, Optional, Sequence +from typing import Collection, Sequence # These are expanded inside the dockerfile to be a fully qualified image name. # e.g. docker.io/library/debian:bookworm @@ -49,7 +49,7 @@ class Builder: def __init__( self, redirect_stdout: bool = False, - docker_build_args: Optional[Sequence[str]] = None, + docker_build_args: Sequence[str] | None = None, ): self.redirect_stdout = redirect_stdout self._docker_build_args = tuple(docker_build_args or ()) @@ -167,7 +167,7 @@ class Builder: def run_builds( builder: Builder, dists: Collection[str], jobs: int = 1, skip_tests: bool = False ) -> None: - def sig(signum: int, _frame: Optional[FrameType]) -> None: + def sig(signum: int, _frame: FrameType | None) -> None: print("Caught SIGINT") builder.kill_containers() diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index db8655c1ce..0fefc23b22 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -43,7 +43,7 @@ import argparse import base64 import json import sys -from typing import Any, Mapping, Optional, Union +from typing import Any, Mapping from urllib import parse as urlparse import requests @@ -103,12 +103,12 @@ def sign_json( def request( - method: Optional[str], + method: str | None, origin_name: str, origin_key: signedjson.types.SigningKey, destination: str, path: str, - content: Optional[str], + content: str | None, verify_tls: bool, ) -> requests.Response: if method is None: @@ -301,9 +301,9 @@ class MatrixConnectionAdapter(HTTPAdapter): def get_connection_with_tls_context( self, request: PreparedRequest, - verify: Optional[Union[bool, str]], - proxies: Optional[Mapping[str, str]] = None, - cert: Optional[Union[tuple[str, str], str]] = None, + verify: bool | str | None, + proxies: Mapping[str, str] | None = None, + cert: tuple[str, str] | str | None = None, ) -> HTTPConnectionPool: # overrides the get_connection_with_tls_context() method in the base class parsed = urlparse.urlsplit(request.url) @@ -368,7 +368,7 @@ class MatrixConnectionAdapter(HTTPAdapter): return server_name, 8448, server_name @staticmethod - def _get_well_known(server_name: str) -> Optional[str]: + def _get_well_known(server_name: str) -> str | None: if ":" in server_name: # explicit port, or ipv6 literal. Either way, no .well-known return None diff --git a/scripts-dev/gen_config_documentation.py b/scripts-dev/gen_config_documentation.py index 9a49c07a34..aad25a4fc1 100755 --- a/scripts-dev/gen_config_documentation.py +++ b/scripts-dev/gen_config_documentation.py @@ -4,7 +4,7 @@ import json import re import sys -from typing import Any, Optional +from typing import Any import yaml @@ -259,17 +259,17 @@ def indent(text: str, first_line: bool = True) -> str: return text -def em(s: Optional[str]) -> str: +def em(s: str | None) -> str: """Add emphasis to text.""" return f"*{s}*" if s else "" -def a(s: Optional[str], suffix: str = " ") -> str: +def a(s: str | None, suffix: str = " ") -> str: """Appends a space if the given string is not empty.""" return s + suffix if s else "" -def p(s: Optional[str], prefix: str = " ") -> str: +def p(s: str | None, prefix: str = " ") -> str: """Prepend a space if the given string is not empty.""" return prefix + s if s else "" diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 830c4ac4ab..24794a1925 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -24,7 +24,7 @@ can crop up, e.g the cache descriptors. """ import enum -from typing import Callable, Mapping, Optional, Union +from typing import Callable, Mapping import attr import mypy.types @@ -123,7 +123,7 @@ class ArgLocation: """ -prometheus_metric_fullname_to_label_arg_map: Mapping[str, Optional[ArgLocation]] = { +prometheus_metric_fullname_to_label_arg_map: Mapping[str, ArgLocation | None] = { # `Collector` subclasses: "prometheus_client.metrics.MetricWrapperBase": ArgLocation("labelnames", 2), "prometheus_client.metrics.Counter": ArgLocation("labelnames", 2), @@ -211,7 +211,7 @@ class SynapsePlugin(Plugin): def get_base_class_hook( self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + ) -> Callable[[ClassDefContext], None] | None: def _get_base_class_hook(ctx: ClassDefContext) -> None: # Run any `get_base_class_hook` checks from other plugins first. # @@ -232,7 +232,7 @@ class SynapsePlugin(Plugin): def get_function_signature_hook( self, fullname: str - ) -> Optional[Callable[[FunctionSigContext], FunctionLike]]: + ) -> Callable[[FunctionSigContext], FunctionLike] | None: # Strip off the unique identifier for classes that are dynamically created inside # functions. ex. `synapse.metrics.jemalloc.JemallocCollector@185` (this is the line # number) @@ -262,7 +262,7 @@ class SynapsePlugin(Plugin): def get_method_signature_hook( self, fullname: str - ) -> Optional[Callable[[MethodSigContext], CallableType]]: + ) -> Callable[[MethodSigContext], CallableType] | None: if fullname.startswith( ( "synapse.util.caches.descriptors.CachedFunction.__call__", @@ -721,7 +721,7 @@ def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType: def check_is_cacheable( signature: CallableType, - ctx: Union[MethodSigContext, FunctionSigContext], + ctx: MethodSigContext | FunctionSigContext, ) -> None: """ Check if a callable returns a type which can be cached. @@ -795,7 +795,7 @@ AT_CACHED_MUTABLE_RETURN = ErrorCode( def is_cacheable( rt: mypy.types.Type, signature: CallableType, verbose: bool -) -> tuple[bool, Optional[str]]: +) -> tuple[bool, str | None]: """ Check if a particular type is cachable. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 262c1503c7..ba95a19382 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -32,7 +32,7 @@ import time import urllib.request from os import path from tempfile import TemporaryDirectory -from typing import Any, Match, Optional, Union +from typing import Any, Match import attr import click @@ -327,11 +327,11 @@ def _prepare() -> None: @cli.command() @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"]) -def tag(gh_token: Optional[str]) -> None: +def tag(gh_token: str | None) -> None: _tag(gh_token) -def _tag(gh_token: Optional[str]) -> None: +def _tag(gh_token: str | None) -> None: """Tags the release and generates a draft GitHub release""" # Test that the GH Token is valid before continuing. @@ -471,11 +471,11 @@ def _publish(gh_token: str) -> None: @cli.command() @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False) -def upload(gh_token: Optional[str]) -> None: +def upload(gh_token: str | None) -> None: _upload(gh_token) -def _upload(gh_token: Optional[str]) -> None: +def _upload(gh_token: str | None) -> None: """Upload release to pypi.""" # Test that the GH Token is valid before continuing. @@ -576,11 +576,11 @@ def _merge_into(repo: Repo, source: str, target: str) -> None: @cli.command() @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False) -def wait_for_actions(gh_token: Optional[str]) -> None: +def wait_for_actions(gh_token: str | None) -> None: _wait_for_actions(gh_token) -def _wait_for_actions(gh_token: Optional[str]) -> None: +def _wait_for_actions(gh_token: str | None) -> None: # Test that the GH Token is valid before continuing. check_valid_gh_token(gh_token) @@ -658,7 +658,7 @@ def _notify(message: str) -> None: envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False, ) -def merge_back(_gh_token: Optional[str]) -> None: +def merge_back(_gh_token: str | None) -> None: _merge_back() @@ -715,7 +715,7 @@ def _merge_back() -> None: envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False, ) -def announce(_gh_token: Optional[str]) -> None: +def announce(_gh_token: str | None) -> None: _announce() @@ -851,7 +851,7 @@ def get_repo_and_check_clean_checkout( return repo -def check_valid_gh_token(gh_token: Optional[str]) -> None: +def check_valid_gh_token(gh_token: str | None) -> None: """Check that a github token is valid, if supplied""" if not gh_token: @@ -867,7 +867,7 @@ def check_valid_gh_token(gh_token: Optional[str]) -> None: raise click.ClickException(f"Github credentials are bad: {e}") -def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]: +def find_ref(repo: git.Repo, ref_name: str) -> git.HEAD | None: """Find the branch/ref, looking first locally then in the remote.""" if ref_name in repo.references: return repo.references[ref_name] @@ -904,7 +904,7 @@ def get_changes_for_version(wanted_version: version.Version) -> str: # These are 0-based. start_line: int - end_line: Optional[int] = None # Is none if its the last entry + end_line: int | None = None # Is none if its the last entry headings: list[VersionSection] = [] for i, token in enumerate(tokens): @@ -991,7 +991,7 @@ def build_dependabot_changelog(repo: Repo, current_version: version.Version) -> messages = [] for commit in reversed(commits): if commit.author.name == "dependabot[bot]": - message: Union[str, bytes] = commit.message + message: str | bytes = commit.message if isinstance(message, bytes): message = message.decode("utf-8") messages.append(message.split("\n", maxsplit=1)[0]) diff --git a/scripts-dev/schema_versions.py b/scripts-dev/schema_versions.py index cec58e177f..b3946ea7a1 100755 --- a/scripts-dev/schema_versions.py +++ b/scripts-dev/schema_versions.py @@ -38,7 +38,7 @@ import io import json import sys from collections import defaultdict -from typing import Any, Iterator, Optional +from typing import Any, Iterator import git from packaging import version @@ -57,7 +57,7 @@ SCHEMA_VERSION_FILES = ( OLDEST_SHOWN_VERSION = version.parse("v1.0") -def get_schema_versions(tag: git.Tag) -> tuple[Optional[int], Optional[int]]: +def get_schema_versions(tag: git.Tag) -> tuple[int | None, int | None]: """Get the schema and schema compat versions for a tag.""" schema_version = None schema_compat_version = None diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi index a0be3e6349..8616f8d4f7 100644 --- a/stubs/sortedcontainers/sorteddict.pyi +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -13,10 +13,8 @@ from typing import ( Iterator, KeysView, Mapping, - Optional, Sequence, TypeVar, - Union, ValuesView, overload, ) @@ -51,7 +49,7 @@ class SortedDict(dict[_KT, _VT]): self, __key: _Key[_KT], __iterable: Iterable[tuple[_KT, _VT]], **kwargs: _VT ) -> None: ... @property - def key(self) -> Optional[_Key[_KT]]: ... + def key(self) -> _Key[_KT] | None: ... @property def iloc(self) -> SortedKeysView[_KT]: ... def clear(self) -> None: ... @@ -79,10 +77,10 @@ class SortedDict(dict[_KT, _VT]): @overload def pop(self, key: _KT) -> _VT: ... @overload - def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ... + def pop(self, key: _KT, default: _T = ...) -> _VT | _T: ... def popitem(self, index: int = ...) -> tuple[_KT, _VT]: ... def peekitem(self, index: int = ...) -> tuple[_KT, _VT]: ... - def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ... + def setdefault(self, key: _KT, default: _VT | None = ...) -> _VT: ... # Mypy now reports the first overload as an error, because typeshed widened the type # of `__map` to its internal `_typeshed.SupportsKeysAndGetItem` type in # https://github.com/python/typeshed/pull/6653 @@ -106,8 +104,8 @@ class SortedDict(dict[_KT, _VT]): def _check(self) -> None: ... def islice( self, - start: Optional[int] = ..., - stop: Optional[int] = ..., + start: int | None = ..., + stop: int | None = ..., reverse: bool = ..., ) -> Iterator[_KT]: ... def bisect_left(self, value: _KT) -> int: ... @@ -118,7 +116,7 @@ class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]): def __getitem__(self, index: int) -> _KT_co: ... @overload def __getitem__(self, index: slice) -> list[_KT_co]: ... - def __delitem__(self, index: Union[int, slice]) -> None: ... + def __delitem__(self, index: int | slice) -> None: ... class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[tuple[_KT_co, _VT_co]]): def __iter__(self) -> Iterator[tuple[_KT_co, _VT_co]]: ... @@ -126,11 +124,11 @@ class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[tuple[_KT_co, _VT_co]] def __getitem__(self, index: int) -> tuple[_KT_co, _VT_co]: ... @overload def __getitem__(self, index: slice) -> list[tuple[_KT_co, _VT_co]]: ... - def __delitem__(self, index: Union[int, slice]) -> None: ... + def __delitem__(self, index: int | slice) -> None: ... class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]): @overload def __getitem__(self, index: int) -> _VT_co: ... @overload def __getitem__(self, index: slice) -> list[_VT_co]: ... - def __delitem__(self, index: Union[int, slice]) -> None: ... + def __delitem__(self, index: int | slice) -> None: ... diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi index 25ceb74cc9..f5e056111b 100644 --- a/stubs/sortedcontainers/sortedlist.pyi +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -10,10 +10,8 @@ from typing import ( Iterable, Iterator, MutableSequence, - Optional, Sequence, TypeVar, - Union, overload, ) @@ -29,8 +27,8 @@ class SortedList(MutableSequence[_T]): DEFAULT_LOAD_FACTOR: int = ... def __init__( self, - iterable: Optional[Iterable[_T]] = ..., - key: Optional[_Key[_T]] = ..., + iterable: Iterable[_T] | None = ..., + key: _Key[_T] | None = ..., ): ... # NB: currently mypy does not honour return type, see mypy #3307 @overload @@ -42,7 +40,7 @@ class SortedList(MutableSequence[_T]): @overload def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ... @property - def key(self) -> Optional[Callable[[_T], Any]]: ... + def key(self) -> Callable[[_T], Any] | None: ... def _reset(self, load: int) -> None: ... def clear(self) -> None: ... def _clear(self) -> None: ... @@ -57,7 +55,7 @@ class SortedList(MutableSequence[_T]): def _pos(self, idx: int) -> int: ... def _build_index(self) -> None: ... def __contains__(self, value: Any) -> bool: ... - def __delitem__(self, index: Union[int, slice]) -> None: ... + def __delitem__(self, index: int | slice) -> None: ... @overload def __getitem__(self, index: int) -> _T: ... @overload @@ -76,8 +74,8 @@ class SortedList(MutableSequence[_T]): def reverse(self) -> None: ... def islice( self, - start: Optional[int] = ..., - stop: Optional[int] = ..., + start: int | None = ..., + stop: int | None = ..., reverse: bool = ..., ) -> Iterator[_T]: ... def _islice( @@ -90,8 +88,8 @@ class SortedList(MutableSequence[_T]): ) -> Iterator[_T]: ... def irange( self, - minimum: Optional[int] = ..., - maximum: Optional[int] = ..., + minimum: int | None = ..., + maximum: int | None = ..., inclusive: tuple[bool, bool] = ..., reverse: bool = ..., ) -> Iterator[_T]: ... @@ -107,7 +105,7 @@ class SortedList(MutableSequence[_T]): def insert(self, index: int, value: _T) -> None: ... def pop(self, index: int = ...) -> _T: ... def index( - self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + self, value: _T, start: int | None = ..., stop: int | None = ... ) -> int: ... def __add__(self: _SL, other: Iterable[_T]) -> _SL: ... def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ... @@ -126,10 +124,10 @@ class SortedList(MutableSequence[_T]): class SortedKeyList(SortedList[_T]): def __init__( - self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + self, iterable: Iterable[_T] | None = ..., key: _Key[_T] = ... ) -> None: ... def __new__( - cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + cls, iterable: Iterable[_T] | None = ..., key: _Key[_T] = ... ) -> SortedKeyList[_T]: ... @property def key(self) -> Callable[[_T], Any]: ... @@ -146,15 +144,15 @@ class SortedKeyList(SortedList[_T]): def _delete(self, pos: int, idx: int) -> None: ... def irange( self, - minimum: Optional[int] = ..., - maximum: Optional[int] = ..., + minimum: int | None = ..., + maximum: int | None = ..., inclusive: tuple[bool, bool] = ..., reverse: bool = ..., ) -> Iterator[_T]: ... def irange_key( self, - min_key: Optional[Any] = ..., - max_key: Optional[Any] = ..., + min_key: Any | None = ..., + max_key: Any | None = ..., inclusive: tuple[bool, bool] = ..., reserve: bool = ..., ) -> Iterator[_T]: ... @@ -170,7 +168,7 @@ class SortedKeyList(SortedList[_T]): def copy(self: _SKL) -> _SKL: ... def __copy__(self: _SKL) -> _SKL: ... def index( - self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + self, value: _T, start: int | None = ..., stop: int | None = ... ) -> int: ... def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ... def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ... diff --git a/stubs/sortedcontainers/sortedset.pyi b/stubs/sortedcontainers/sortedset.pyi index a3593ca579..da2696b262 100644 --- a/stubs/sortedcontainers/sortedset.pyi +++ b/stubs/sortedcontainers/sortedset.pyi @@ -11,10 +11,8 @@ from typing import ( Iterable, Iterator, MutableSet, - Optional, Sequence, TypeVar, - Union, overload, ) @@ -28,21 +26,19 @@ _Key = Callable[[_T], Any] class SortedSet(MutableSet[_T], Sequence[_T]): def __init__( self, - iterable: Optional[Iterable[_T]] = ..., - key: Optional[_Key[_T]] = ..., + iterable: Iterable[_T] | None = ..., + key: _Key[_T] | None = ..., ) -> None: ... @classmethod - def _fromset( - cls, values: set[_T], key: Optional[_Key[_T]] = ... - ) -> SortedSet[_T]: ... + def _fromset(cls, values: set[_T], key: _Key[_T] | None = ...) -> SortedSet[_T]: ... @property - def key(self) -> Optional[_Key[_T]]: ... + def key(self) -> _Key[_T] | None: ... def __contains__(self, value: Any) -> bool: ... @overload def __getitem__(self, index: int) -> _T: ... @overload def __getitem__(self, index: slice) -> list[_T]: ... - def __delitem__(self, index: Union[int, slice]) -> None: ... + def __delitem__(self, index: int | slice) -> None: ... def __eq__(self, other: Any) -> bool: ... def __ne__(self, other: Any) -> bool: ... def __lt__(self, other: Iterable[_T]) -> bool: ... @@ -62,32 +58,28 @@ class SortedSet(MutableSet[_T], Sequence[_T]): def _discard(self, value: _T) -> None: ... def pop(self, index: int = ...) -> _T: ... def remove(self, value: _T) -> None: ... - def difference(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def __sub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def difference_update( - self, *iterables: Iterable[_S] - ) -> SortedSet[Union[_T, _S]]: ... - def __isub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def intersection(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def __and__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def __rand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def intersection_update( - self, *iterables: Iterable[_S] - ) -> SortedSet[Union[_T, _S]]: ... - def __iand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def symmetric_difference(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def __xor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def __rxor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def difference(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __sub__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def difference_update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __isub__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def intersection(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __and__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __rand__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def intersection_update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __iand__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def symmetric_difference(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __xor__(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __rxor__(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ... def symmetric_difference_update( self, other: Iterable[_S] - ) -> SortedSet[Union[_T, _S]]: ... - def __ixor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def union(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def __or__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def __ror__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def __ior__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... - def _update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + ) -> SortedSet[_T | _S]: ... + def __ixor__(self, other: Iterable[_S]) -> SortedSet[_T | _S]: ... + def union(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __or__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __ror__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def __ior__(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... + def _update(self, *iterables: Iterable[_S]) -> SortedSet[_T | _S]: ... def __reduce__( self, ) -> tuple[type[SortedSet[_T]], set[_T], Callable[[_T], Any]]: ... @@ -97,18 +89,18 @@ class SortedSet(MutableSet[_T], Sequence[_T]): def bisect_right(self, value: _T) -> int: ... def islice( self, - start: Optional[int] = ..., - stop: Optional[int] = ..., + start: int | None = ..., + stop: int | None = ..., reverse: bool = ..., ) -> Iterator[_T]: ... def irange( self, - minimum: Optional[_T] = ..., - maximum: Optional[_T] = ..., + minimum: _T | None = ..., + maximum: _T | None = ..., inclusive: tuple[bool, bool] = ..., reverse: bool = ..., ) -> Iterator[_T]: ... def index( - self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + self, value: _T, start: int | None = ..., stop: int | None = ... ) -> int: ... def _reset(self, load: int) -> None: ... diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index d2539aa37d..50ab54037a 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -15,7 +15,7 @@ """Contains *incomplete* type hints for txredisapi.""" -from typing import Any, Optional, Union +from typing import Any from twisted.internet import protocol from twisted.internet.defer import Deferred @@ -29,8 +29,8 @@ class RedisProtocol(protocol.Protocol): self, key: str, value: Any, - expire: Optional[int] = None, - pexpire: Optional[int] = None, + expire: int | None = None, + pexpire: int | None = None, only_if_not_exists: bool = False, only_if_exists: bool = False, ) -> "Deferred[None]": ... @@ -38,8 +38,8 @@ class RedisProtocol(protocol.Protocol): class SubscriberProtocol(RedisProtocol): def __init__(self, *args: object, **kwargs: object): ... - password: Optional[str] - def subscribe(self, channels: Union[str, list[str]]) -> "Deferred[None]": ... + password: str | None + def subscribe(self, channels: str | list[str]) -> "Deferred[None]": ... def connectionMade(self) -> None: ... # type-ignore: twisted.internet.protocol.Protocol provides a default argument for # `reason`. txredisapi's LineReceiver Protocol doesn't. But that's fine: it's what's @@ -49,12 +49,12 @@ class SubscriberProtocol(RedisProtocol): def lazyConnection( host: str = ..., port: int = ..., - dbid: Optional[int] = ..., + dbid: int | None = ..., reconnect: bool = ..., charset: str = ..., - password: Optional[str] = ..., - connectTimeout: Optional[int] = ..., - replyTimeout: Optional[int] = ..., + password: str | None = ..., + connectTimeout: int | None = ..., + replyTimeout: int | None = ..., convertNumbers: bool = ..., ) -> RedisProtocol: ... @@ -70,18 +70,18 @@ class RedisFactory(protocol.ReconnectingClientFactory): continueTrying: bool handler: ConnectionHandler pool: list[RedisProtocol] - replyTimeout: Optional[int] + replyTimeout: int | None def __init__( self, uuid: str, - dbid: Optional[int], + dbid: int | None, poolsize: int, isLazy: bool = False, handler: type = ConnectionHandler, charset: str = "utf-8", - password: Optional[str] = None, - replyTimeout: Optional[int] = None, - convertNumbers: Optional[int] = True, + password: str | None = None, + replyTimeout: int | None = None, + convertNumbers: int | None = True, ): ... def buildProtocol(self, addr: IAddress) -> RedisProtocol: ... diff --git a/synapse/_scripts/export_signing_key.py b/synapse/_scripts/export_signing_key.py index 690115aabe..bab5953802 100755 --- a/synapse/_scripts/export_signing_key.py +++ b/synapse/_scripts/export_signing_key.py @@ -22,13 +22,13 @@ import argparse import sys import time -from typing import NoReturn, Optional +from typing import NoReturn from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys from signedjson.types import VerifyKey -def exit(status: int = 0, message: Optional[str] = None) -> NoReturn: +def exit(status: int = 0, message: str | None = None) -> NoReturn: if message: print(message, file=sys.stderr) sys.exit(status) diff --git a/synapse/_scripts/generate_workers_map.py b/synapse/_scripts/generate_workers_map.py index e669f6902d..3fa27b4b2a 100755 --- a/synapse/_scripts/generate_workers_map.py +++ b/synapse/_scripts/generate_workers_map.py @@ -25,7 +25,7 @@ import logging import re from collections import defaultdict from dataclasses import dataclass -from typing import Iterable, Optional, Pattern +from typing import Iterable, Pattern import yaml @@ -46,7 +46,7 @@ logger = logging.getLogger("generate_workers_map") class MockHomeserver(HomeServer): DATASTORE_CLASS = DataStore - def __init__(self, config: HomeServerConfig, worker_app: Optional[str]) -> None: + def __init__(self, config: HomeServerConfig, worker_app: str | None) -> None: super().__init__(config.server.server_name, config=config) self.config.worker.worker_app = worker_app @@ -65,7 +65,7 @@ class EndpointDescription: # The category of this endpoint. Is read from the `CATEGORY` constant in the servlet # class. - category: Optional[str] + category: str | None # TODO: # - does it need to be routed based on a stream writer config? @@ -141,7 +141,7 @@ def get_registered_paths_for_hs( def get_registered_paths_for_default( - worker_app: Optional[str], base_config: HomeServerConfig + worker_app: str | None, base_config: HomeServerConfig ) -> dict[tuple[str, str], EndpointDescription]: """ Given the name of a worker application and a base homeserver configuration, @@ -271,7 +271,7 @@ def main() -> None: # TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT categories_to_methods_and_paths: dict[ - Optional[str], dict[tuple[str, str], EndpointDescription] + str | None, dict[tuple[str, str], EndpointDescription] ] = defaultdict(dict) for (method, path), desc in elided_worker_paths.items(): @@ -282,7 +282,7 @@ def main() -> None: def print_category( - category_name: Optional[str], + category_name: str | None, elided_worker_paths: dict[tuple[str, str], EndpointDescription], ) -> None: """ diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 3fe2f33e52..1ce8221218 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -26,7 +26,7 @@ import hashlib import hmac import logging import sys -from typing import Any, Callable, Optional +from typing import Any, Callable import requests import yaml @@ -54,7 +54,7 @@ def request_registration( server_location: str, shared_secret: str, admin: bool = False, - user_type: Optional[str] = None, + user_type: str | None = None, _print: Callable[[str], None] = print, exit: Callable[[int], None] = sys.exit, exists_ok: bool = False, @@ -123,13 +123,13 @@ def register_new_user( password: str, server_location: str, shared_secret: str, - admin: Optional[bool], - user_type: Optional[str], + admin: bool | None, + user_type: str | None, exists_ok: bool = False, ) -> None: if not user: try: - default_user: Optional[str] = getpass.getuser() + default_user: str | None = getpass.getuser() except Exception: default_user = None @@ -262,7 +262,7 @@ def main() -> None: args = parser.parse_args() - config: Optional[dict[str, Any]] = None + config: dict[str, Any] | None = None if "config" in args and args.config: config = yaml.safe_load(args.config) @@ -350,7 +350,7 @@ def _read_file(file_path: Any, config_path: str) -> str: sys.exit(1) -def _find_client_listener(config: dict[str, Any]) -> Optional[str]: +def _find_client_listener(config: dict[str, Any]) -> str | None: # try to find a listener in the config. Returns a host:port pair for listener in config.get("listeners", []): if listener.get("type") != "http" or listener.get("tls", False): diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index e83c0de5a4..1806e42d90 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -233,14 +233,14 @@ IGNORED_BACKGROUND_UPDATES = { # Error returned by the run function. Used at the top-level part of the script to # handle errors and return codes. -end_error: Optional[str] = None +end_error: str | None = None # The exec_info for the error, if any. If error is defined but not exec_info the script # will show only the error message without the stacktrace, if exec_info is defined but # not the error then the script will show nothing outside of what's printed in the run # function. If both are defined, the script will print both the error and the stacktrace. -end_error_exec_info: Optional[ - tuple[type[BaseException], BaseException, TracebackType] -] = None +end_error_exec_info: tuple[type[BaseException], BaseException, TracebackType] | None = ( + None +) R = TypeVar("R") @@ -485,7 +485,7 @@ class Porter: def r( txn: LoggingTransaction, - ) -> tuple[Optional[list[str]], list[tuple], list[tuple]]: + ) -> tuple[list[str] | None, list[tuple], list[tuple]]: forward_rows = [] backward_rows = [] if do_forward[0]: @@ -502,7 +502,7 @@ class Porter: if forward_rows or backward_rows: assert txn.description is not None - headers: Optional[list[str]] = [ + headers: list[str] | None = [ column[0] for column in txn.description ] else: @@ -1152,9 +1152,7 @@ class Porter: return done, remaining + done async def _setup_state_group_id_seq(self) -> None: - curr_id: Optional[ - int - ] = await self.sqlite_store.db_pool.simple_select_one_onecol( + curr_id: int | None = await self.sqlite_store.db_pool.simple_select_one_onecol( table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True ) @@ -1271,10 +1269,10 @@ class Porter: await self.postgres_store.db_pool.runInteraction("_setup_%s" % (seq_name,), r) - async def _pg_get_serial_sequence(self, table: str, column: str) -> Optional[str]: + async def _pg_get_serial_sequence(self, table: str, column: str) -> str | None: """Returns the name of the postgres sequence associated with a column, or NULL.""" - def r(txn: LoggingTransaction) -> Optional[str]: + def r(txn: LoggingTransaction) -> str | None: txn.execute("SELECT pg_get_serial_sequence('%s', '%s')" % (table, column)) result = txn.fetchone() if not result: @@ -1286,9 +1284,9 @@ class Porter: ) async def _setup_auth_chain_sequence(self) -> None: - curr_chain_id: Optional[ - int - ] = await self.sqlite_store.db_pool.simple_select_one_onecol( + curr_chain_id: ( + int | None + ) = await self.sqlite_store.db_pool.simple_select_one_onecol( table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py index 2e2aa27a17..29ab955c45 100755 --- a/synapse/_scripts/synctl.py +++ b/synapse/_scripts/synctl.py @@ -30,7 +30,7 @@ import signal import subprocess import sys import time -from typing import Iterable, NoReturn, Optional, TextIO +from typing import Iterable, NoReturn, TextIO import yaml @@ -135,7 +135,7 @@ def start(pidfile: str, app: str, config_files: Iterable[str], daemonize: bool) return False -def stop(pidfile: str, app: str) -> Optional[int]: +def stop(pidfile: str, app: str) -> int | None: """Attempts to kill a synapse worker from the pidfile. Args: pidfile: path to file containing worker's pid diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py index cc0c0d4601..201c295f06 100644 --- a/synapse/api/auth/__init__.py +++ b/synapse/api/auth/__init__.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Optional, Protocol +from typing import TYPE_CHECKING, Protocol from prometheus_client import Histogram @@ -51,7 +51,7 @@ class Auth(Protocol): room_id: str, requester: Requester, allow_departed_users: bool = False, - ) -> tuple[str, Optional[str]]: + ) -> tuple[str, str | None]: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. @@ -190,7 +190,7 @@ class Auth(Protocol): async def check_user_in_room_or_world_readable( self, room_id: str, requester: Requester, allow_departed_users: bool = False - ) -> tuple[str, Optional[str]]: + ) -> tuple[str, str | None]: """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. diff --git a/synapse/api/auth/base.py b/synapse/api/auth/base.py index d5635e588f..ff876b9d22 100644 --- a/synapse/api/auth/base.py +++ b/synapse/api/auth/base.py @@ -19,7 +19,7 @@ # # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from netaddr import IPAddress @@ -64,7 +64,7 @@ class BaseAuth: room_id: str, requester: Requester, allow_departed_users: bool = False, - ) -> tuple[str, Optional[str]]: + ) -> tuple[str, str | None]: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. @@ -114,7 +114,7 @@ class BaseAuth: @trace async def check_user_in_room_or_world_readable( self, room_id: str, requester: Requester, allow_departed_users: bool = False - ) -> tuple[str, Optional[str]]: + ) -> tuple[str, str | None]: """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. @@ -294,7 +294,7 @@ class BaseAuth: @cancellable async def get_appservice_user( self, request: Request, access_token: str - ) -> Optional[Requester]: + ) -> Requester | None: """ Given a request, reads the request parameters to determine: - whether it's an application service that's making this request diff --git a/synapse/api/auth/mas.py b/synapse/api/auth/mas.py index f2b218e34f..e422a1e5c5 100644 --- a/synapse/api/auth/mas.py +++ b/synapse/api/auth/mas.py @@ -13,7 +13,7 @@ # # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from urllib.parse import urlencode from pydantic import ( @@ -74,11 +74,11 @@ class ServerMetadata(BaseModel): class IntrospectionResponse(BaseModel): retrieved_at_ms: StrictInt active: StrictBool - scope: Optional[StrictStr] = None - username: Optional[StrictStr] = None - sub: Optional[StrictStr] = None - device_id: Optional[StrictStr] = None - expires_in: Optional[StrictInt] = None + scope: StrictStr | None = None + username: StrictStr | None = None + sub: StrictStr | None = None + device_id: StrictStr | None = None + expires_in: StrictInt | None = None model_config = ConfigDict(extra="allow") def get_scope_set(self) -> set[str]: diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index 48b32aa04a..7999d6e459 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -20,7 +20,7 @@ # import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable from urllib.parse import urlencode from authlib.oauth2 import ClientAuth @@ -102,25 +102,25 @@ class IntrospectionResult: return [] return scope_to_list(value) - def get_sub(self) -> Optional[str]: + def get_sub(self) -> str | None: value = self._inner.get("sub") if not isinstance(value, str): return None return value - def get_username(self) -> Optional[str]: + def get_username(self) -> str | None: value = self._inner.get("username") if not isinstance(value, str): return None return value - def get_name(self) -> Optional[str]: + def get_name(self) -> str | None: value = self._inner.get("name") if not isinstance(value, str): return None return value - def get_device_id(self) -> Optional[str]: + def get_device_id(self) -> str | None: value = self._inner.get("device_id") if value is not None and not isinstance(value, str): raise AuthError( @@ -174,7 +174,7 @@ class MSC3861DelegatedAuth(BaseAuth): self._clock = hs.get_clock() self._http_client = hs.get_proxied_http_client() self._hostname = hs.hostname - self._admin_token: Callable[[], Optional[str]] = self._config.admin_token + self._admin_token: Callable[[], str | None] = self._config.admin_token self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users self._rust_http_client = HttpClient( @@ -247,7 +247,7 @@ class MSC3861DelegatedAuth(BaseAuth): metadata = await self._issuer_metadata.get() return metadata.issuer or self._config.issuer - async def account_management_url(self) -> Optional[str]: + async def account_management_url(self) -> str | None: """ Get the configured account management URL diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 303c9ba03e..3ed47b20c4 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError @@ -51,10 +51,10 @@ class AuthBlocking: async def check_auth_blocking( self, - user_id: Optional[str] = None, - threepid: Optional[dict] = None, - user_type: Optional[str] = None, - requester: Optional[Requester] = None, + user_id: str | None = None, + threepid: dict | None = None, + user_type: str | None = None, + requester: Requester | None = None, ) -> None: """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag diff --git a/synapse/api/errors.py b/synapse/api/errors.py index f75b34ef69..c4339ebef8 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -26,7 +26,7 @@ import math import typing from enum import Enum from http import HTTPStatus -from typing import Any, Optional, Union +from typing import Any, Optional from twisted.web import http @@ -164,9 +164,9 @@ class CodeMessageException(RuntimeError): def __init__( self, - code: Union[int, HTTPStatus], + code: int | HTTPStatus, msg: str, - headers: Optional[dict[str, str]] = None, + headers: dict[str, str] | None = None, ): super().__init__("%d: %s" % (code, msg)) @@ -223,8 +223,8 @@ class SynapseError(CodeMessageException): code: int, msg: str, errcode: str = Codes.UNKNOWN, - additional_fields: Optional[dict] = None, - headers: Optional[dict[str, str]] = None, + additional_fields: dict | None = None, + headers: dict[str, str] | None = None, ): """Constructs a synapse error. @@ -244,7 +244,7 @@ class SynapseError(CodeMessageException): return cs_error(self.msg, self.errcode, **self._additional_fields) @property - def debug_context(self) -> Optional[str]: + def debug_context(self) -> str | None: """Override this to add debugging context that shouldn't be sent to clients.""" return None @@ -276,7 +276,7 @@ class ProxiedRequestError(SynapseError): code: int, msg: str, errcode: str = Codes.UNKNOWN, - additional_fields: Optional[dict] = None, + additional_fields: dict | None = None, ): super().__init__(code, msg, errcode, additional_fields) @@ -340,7 +340,7 @@ class FederationDeniedError(SynapseError): destination: The destination which has been denied """ - def __init__(self, destination: Optional[str]): + def __init__(self, destination: str | None): """Raised by federation client or server to indicate that we are are deliberately not attempting to contact a given server because it is not on our federation whitelist. @@ -399,7 +399,7 @@ class AuthError(SynapseError): code: int, msg: str, errcode: str = Codes.FORBIDDEN, - additional_fields: Optional[dict] = None, + additional_fields: dict | None = None, ): super().__init__(code, msg, errcode, additional_fields) @@ -432,7 +432,7 @@ class UnstableSpecAuthError(AuthError): msg: str, errcode: str, previous_errcode: str = Codes.FORBIDDEN, - additional_fields: Optional[dict] = None, + additional_fields: dict | None = None, ): self.previous_errcode = previous_errcode super().__init__(code, msg, errcode, additional_fields) @@ -497,8 +497,8 @@ class ResourceLimitError(SynapseError): code: int, msg: str, errcode: str = Codes.RESOURCE_LIMIT_EXCEEDED, - admin_contact: Optional[str] = None, - limit_type: Optional[str] = None, + admin_contact: str | None = None, + limit_type: str | None = None, ): self.admin_contact = admin_contact self.limit_type = limit_type @@ -542,7 +542,7 @@ class InvalidCaptchaError(SynapseError): self, code: int = 400, msg: str = "Invalid captcha.", - error_url: Optional[str] = None, + error_url: str | None = None, errcode: str = Codes.CAPTCHA_INVALID, ): super().__init__(code, msg, errcode) @@ -563,9 +563,9 @@ class LimitExceededError(SynapseError): self, limiter_name: str, code: int = 429, - retry_after_ms: Optional[int] = None, + retry_after_ms: int | None = None, errcode: str = Codes.LIMIT_EXCEEDED, - pause: Optional[float] = None, + pause: float | None = None, ): # Use HTTP header Retry-After to enable library-assisted retry handling. headers = ( @@ -582,7 +582,7 @@ class LimitExceededError(SynapseError): return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) @property - def debug_context(self) -> Optional[str]: + def debug_context(self) -> str | None: return self.limiter_name @@ -675,7 +675,7 @@ class RequestSendFailed(RuntimeError): class UnredactedContentDeletedError(SynapseError): - def __init__(self, content_keep_ms: Optional[int] = None): + def __init__(self, content_keep_ms: int | None = None): super().__init__( 404, "The content for that event has already been erased from the database", @@ -751,7 +751,7 @@ class FederationError(RuntimeError): code: int, reason: str, affected: str, - source: Optional[str] = None, + source: str | None = None, ): if level not in ["FATAL", "ERROR", "WARN"]: raise ValueError("Level is not valid: %s" % (level,)) @@ -786,7 +786,7 @@ class FederationPullAttemptBackoffError(RuntimeError): """ def __init__( - self, event_ids: "StrCollection", message: Optional[str], retry_after_ms: int + self, event_ids: "StrCollection", message: str | None, retry_after_ms: int ): event_ids = list(event_ids) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index e31bec1a00..9b47c20437 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -28,9 +28,7 @@ from typing import ( Collection, Iterable, Mapping, - Optional, TypeVar, - Union, ) import jsonschema @@ -155,7 +153,7 @@ class Filtering: self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) async def get_user_filter( - self, user_id: UserID, filter_id: Union[int, str] + self, user_id: UserID, filter_id: int | str ) -> "FilterCollection": result = await self.store.get_user_filter(user_id, filter_id) return FilterCollection(self._hs, result) @@ -531,7 +529,7 @@ class Filter: return newFilter -def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool: +def _matches_wildcard(actual_value: str | None, filter_value: str) -> bool: if filter_value.endswith("*") and isinstance(actual_value, str): type_prefix = filter_value[:-1] return actual_value.startswith(type_prefix) diff --git a/synapse/api/presence.py b/synapse/api/presence.py index 28c10403ce..0e2fe625c9 100644 --- a/synapse/api/presence.py +++ b/synapse/api/presence.py @@ -19,7 +19,7 @@ # # -from typing import Any, Optional +from typing import Any import attr @@ -41,15 +41,13 @@ class UserDevicePresenceState: """ user_id: str - device_id: Optional[str] + device_id: str | None state: str last_active_ts: int last_sync_ts: int @classmethod - def default( - cls, user_id: str, device_id: Optional[str] - ) -> "UserDevicePresenceState": + def default(cls, user_id: str, device_id: str | None) -> "UserDevicePresenceState": """Returns a default presence state.""" return cls( user_id=user_id, @@ -81,7 +79,7 @@ class UserPresenceState: last_active_ts: int last_federation_update_ts: int last_user_sync_ts: int - status_msg: Optional[str] + status_msg: str | None currently_active: bool def as_dict(self) -> JsonDict: diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index ee0e9181ce..df884d47d7 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -102,9 +102,7 @@ class Ratelimiter: self.clock.looping_call(self._prune_message_counts, 15 * 1000) - def _get_key( - self, requester: Optional[Requester], key: Optional[Hashable] - ) -> Hashable: + def _get_key(self, requester: Requester | None, key: Hashable | None) -> Hashable: """Use the requester's MXID as a fallback key if no key is provided.""" if key is None: if not requester: @@ -121,13 +119,13 @@ class Ratelimiter: async def can_do_action( self, - requester: Optional[Requester], - key: Optional[Hashable] = None, - rate_hz: Optional[float] = None, - burst_count: Optional[int] = None, + requester: Requester | None, + key: Hashable | None = None, + rate_hz: float | None = None, + burst_count: int | None = None, update: bool = True, n_actions: int = 1, - _time_now_s: Optional[float] = None, + _time_now_s: float | None = None, ) -> tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? @@ -247,10 +245,10 @@ class Ratelimiter: def record_action( self, - requester: Optional[Requester], - key: Optional[Hashable] = None, + requester: Requester | None, + key: Hashable | None = None, n_actions: int = 1, - _time_now_s: Optional[float] = None, + _time_now_s: float | None = None, ) -> None: """Record that an action(s) took place, even if they violate the rate limit. @@ -332,14 +330,14 @@ class Ratelimiter: async def ratelimit( self, - requester: Optional[Requester], - key: Optional[Hashable] = None, - rate_hz: Optional[float] = None, - burst_count: Optional[int] = None, + requester: Requester | None, + key: Hashable | None = None, + rate_hz: float | None = None, + burst_count: int | None = None, update: bool = True, n_actions: int = 1, - _time_now_s: Optional[float] = None, - pause: Optional[float] = 0.5, + _time_now_s: float | None = None, + pause: float | None = 0.5, ) -> None: """Checks if an action can be performed. If not, raises a LimitExceededError @@ -396,7 +394,7 @@ class RequestRatelimiter: store: DataStore, clock: Clock, rc_message: RatelimitSettings, - rc_admin_redaction: Optional[RatelimitSettings], + rc_admin_redaction: RatelimitSettings | None, ): self.store = store self.clock = clock @@ -412,7 +410,7 @@ class RequestRatelimiter: # Check whether ratelimiting room admin message redaction is enabled # by the presence of rate limits in the config if rc_admin_redaction: - self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( + self.admin_redaction_ratelimiter: Ratelimiter | None = Ratelimiter( store=self.store, clock=self.clock, cfg=rc_admin_redaction, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index b6e76379f1..97dac661a3 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -18,7 +18,7 @@ # # -from typing import Callable, Optional +from typing import Callable import attr @@ -503,7 +503,7 @@ class RoomVersionCapability: """An object which describes the unique attributes of a room version.""" identifier: str # the identifier for this capability - preferred_version: Optional[RoomVersion] + preferred_version: RoomVersion | None support_check_lambda: Callable[[RoomVersion], bool] diff --git a/synapse/api/urls.py b/synapse/api/urls.py index baa6e2d390..b6147353d4 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -24,7 +24,6 @@ import hmac import urllib.parse from hashlib import sha256 -from typing import Optional from urllib.parse import urlencode, urljoin from synapse.config import ConfigError @@ -75,7 +74,7 @@ class LoginSSORedirectURIBuilder: self._public_baseurl = hs_config.server.public_baseurl def build_login_sso_redirect_uri( - self, *, idp_id: Optional[str], client_redirect_url: str + self, *, idp_id: str | None, client_redirect_url: str ) -> str: """Build a `/login/sso/redirect` URI for the given identity provider. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2de5bdb51e..52bdb9e0d7 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -36,8 +36,6 @@ from typing import ( Awaitable, Callable, NoReturn, - Optional, - Union, cast, ) from wsgiref.simple_server import WSGIServer @@ -180,8 +178,8 @@ def start_worker_reactor( def start_reactor( appname: str, soft_file_limit: int, - gc_thresholds: Optional[tuple[int, int, int]], - pid_file: Optional[str], + gc_thresholds: tuple[int, int, int] | None, + pid_file: str | None, daemonize: bool, print_pidfile: bool, logger: logging.Logger, @@ -421,7 +419,7 @@ def listen_http( root_resource: Resource, version_string: str, max_request_body_size: int, - context_factory: Optional[IOpenSSLContextFactory], + context_factory: IOpenSSLContextFactory | None, reactor: ISynapseReactor = reactor, ) -> list[Port]: """ @@ -564,9 +562,7 @@ def setup_sighup_handling() -> None: if _already_setup_sighup_handling: return - previous_sighup_handler: Union[ - Callable[[int, Optional[FrameType]], Any], int, None - ] = None + previous_sighup_handler: Callable[[int, FrameType | None], Any] | int | None = None # Set up the SIGHUP machinery. if hasattr(signal, "SIGHUP"): diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 193482b7fc..facc98164e 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -24,7 +24,7 @@ import logging import os import sys import tempfile -from typing import Mapping, Optional, Sequence +from typing import Mapping, Sequence from twisted.internet import defer, task @@ -136,7 +136,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): to a temporary directory. """ - def __init__(self, user_id: str, directory: Optional[str] = None): + def __init__(self, user_id: str, directory: str | None = None): self.user_id = user_id if directory: @@ -291,7 +291,7 @@ def load_config(argv_options: list[str]) -> tuple[HomeServerConfig, argparse.Nam def create_homeserver( config: HomeServerConfig, - reactor: Optional[ISynapseReactor] = None, + reactor: ISynapseReactor | None = None, ) -> AdminCmdServer: """ Create a homeserver instance for the Synapse admin command process. diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py index 73e33d77a5..dcb45e234b 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py @@ -26,7 +26,7 @@ import os import signal import sys from types import FrameType -from typing import Any, Callable, Optional +from typing import Any, Callable from twisted.internet.main import installReactor @@ -172,7 +172,7 @@ def main() -> None: # Install signal handlers to propagate signals to all our children, so that they # shut down cleanly. This also inhibits our own exit, but that's good: we want to # wait until the children have exited. - def handle_signal(signum: int, frame: Optional[FrameType]) -> None: + def handle_signal(signum: int, frame: FrameType | None) -> None: print( f"complement_fork_starter: Caught signal {signum}. Stopping children.", file=sys.stderr, diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 0a4abd1839..9939c0fe7d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -21,7 +21,6 @@ # import logging import sys -from typing import Optional from twisted.web.resource import Resource @@ -336,7 +335,7 @@ def load_config(argv_options: list[str]) -> HomeServerConfig: def create_homeserver( config: HomeServerConfig, - reactor: Optional[ISynapseReactor] = None, + reactor: ISynapseReactor | None = None, ) -> GenericWorkerServer: """ Create a homeserver instance for the Synapse worker process. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bd51aad9ab..8fb906cdf7 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -22,7 +22,7 @@ import logging import os import sys -from typing import Iterable, Optional +from typing import Iterable from twisted.internet.tcp import Port from twisted.web.resource import EncodingResourceWrapper, Resource @@ -350,7 +350,7 @@ def load_or_generate_config(argv_options: list[str]) -> HomeServerConfig: def create_homeserver( config: HomeServerConfig, - reactor: Optional[ISynapseReactor] = None, + reactor: ISynapseReactor | None = None, ) -> SynapseHomeServer: """ Create a homeserver instance for the Synapse main process. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index e91fa3a624..620aa29dfc 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -26,7 +26,6 @@ from enum import Enum from typing import ( TYPE_CHECKING, Iterable, - Optional, Pattern, Sequence, cast, @@ -95,12 +94,12 @@ class ApplicationService: token: str, id: str, sender: UserID, - url: Optional[str] = None, - namespaces: Optional[JsonDict] = None, - hs_token: Optional[str] = None, - protocols: Optional[Iterable[str]] = None, + url: str | None = None, + namespaces: JsonDict | None = None, + hs_token: str | None = None, + protocols: Iterable[str] | None = None, rate_limited: bool = True, - ip_range_whitelist: Optional[IPSet] = None, + ip_range_whitelist: IPSet | None = None, supports_ephemeral: bool = False, msc3202_transaction_extensions: bool = False, msc4190_device_management: bool = False, @@ -142,7 +141,7 @@ class ApplicationService: self.rate_limited = rate_limited def _check_namespaces( - self, namespaces: Optional[JsonDict] + self, namespaces: JsonDict | None ) -> dict[str, list[Namespace]]: # Sanity check that it is of the form: # { @@ -179,9 +178,7 @@ class ApplicationService: return result - def _matches_regex( - self, namespace_key: str, test_string: str - ) -> Optional[Namespace]: + def _matches_regex(self, namespace_key: str, test_string: str) -> Namespace | None: for namespace in self.namespaces[namespace_key]: if namespace.regex.match(test_string): return namespace diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index f08a921998..71094de9be 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -25,10 +25,8 @@ from typing import ( TYPE_CHECKING, Iterable, Mapping, - Optional, Sequence, TypeVar, - Union, ) from prometheus_client import Counter @@ -222,7 +220,7 @@ class ApplicationServiceApi(SimpleHttpClient): assert service.hs_token is not None try: - args: Mapping[bytes, Union[list[bytes], str]] = fields + args: Mapping[bytes, list[bytes] | str] = fields if self.config.use_appservice_legacy_authorization: args = { **fields, @@ -258,11 +256,11 @@ class ApplicationServiceApi(SimpleHttpClient): async def get_3pe_protocol( self, service: "ApplicationService", protocol: str - ) -> Optional[JsonDict]: + ) -> JsonDict | None: if service.url is None: return {} - async def _get() -> Optional[JsonDict]: + async def _get() -> JsonDict | None: # This is required by the configuration. assert service.hs_token is not None try: @@ -300,7 +298,7 @@ class ApplicationServiceApi(SimpleHttpClient): key = (service.id, protocol) return await self.protocol_meta_cache.wrap(key, _get) - async def ping(self, service: "ApplicationService", txn_id: Optional[str]) -> None: + async def ping(self, service: "ApplicationService", txn_id: str | None) -> None: # The caller should check that url is set assert service.url is not None, "ping called without URL being set" @@ -322,7 +320,7 @@ class ApplicationServiceApi(SimpleHttpClient): one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, - txn_id: Optional[int] = None, + txn_id: int | None = None, ) -> bool: """ Push data to an application service. diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index b5fab5f50d..30c22780bd 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -62,7 +62,6 @@ from typing import ( Callable, Collection, Iterable, - Optional, Sequence, ) @@ -123,10 +122,10 @@ class ApplicationServiceScheduler: def enqueue_for_appservice( self, appservice: ApplicationService, - events: Optional[Collection[EventBase]] = None, - ephemeral: Optional[Collection[JsonMapping]] = None, - to_device_messages: Optional[Collection[JsonMapping]] = None, - device_list_summary: Optional[DeviceListUpdates] = None, + events: Collection[EventBase] | None = None, + ephemeral: Collection[JsonMapping] | None = None, + to_device_messages: Collection[JsonMapping] | None = None, + device_list_summary: DeviceListUpdates | None = None, ) -> None: """ Enqueue some data to be sent off to an application service. @@ -260,8 +259,8 @@ class _ServiceQueuer: ): return - one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None - unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None + one_time_keys_count: TransactionOneTimeKeysCount | None = None + unused_fallback_keys: TransactionUnusedFallbackKeys | None = None if ( self._msc3202_transaction_extensions_enabled @@ -369,11 +368,11 @@ class _TransactionController: self, service: ApplicationService, events: Sequence[EventBase], - ephemeral: Optional[list[JsonMapping]] = None, - to_device_messages: Optional[list[JsonMapping]] = None, - one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, - unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, - device_list_summary: Optional[DeviceListUpdates] = None, + ephemeral: list[JsonMapping] | None = None, + to_device_messages: list[JsonMapping] | None = None, + one_time_keys_count: TransactionOneTimeKeysCount | None = None, + unused_fallback_keys: TransactionUnusedFallbackKeys | None = None, + device_list_summary: DeviceListUpdates | None = None, ) -> None: """ Create a transaction with the given data and send to the provided @@ -504,7 +503,7 @@ class _Recoverer: self.service = service self.callback = callback self.backoff_counter = 1 - self.scheduled_recovery: Optional[IDelayedCall] = None + self.scheduled_recovery: IDelayedCall | None = None def recover(self) -> None: delay = 2**self.backoff_counter diff --git a/synapse/config/_base.py b/synapse/config/_base.py index ce06905390..95a00c6718 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -36,9 +36,7 @@ from typing import ( Iterable, Iterator, MutableMapping, - Optional, TypeVar, - Union, ) import attr @@ -60,7 +58,7 @@ class ConfigError(Exception): the problem lies. """ - def __init__(self, msg: str, path: Optional[StrSequence] = None): + def __init__(self, msg: str, path: StrSequence | None = None): self.msg = msg self.path = path @@ -175,7 +173,7 @@ class Config: ) @staticmethod - def parse_size(value: Union[str, int]) -> int: + def parse_size(value: str | int) -> int: """Interpret `value` as a number of bytes. If an integer is provided it is treated as bytes and is unchanged. @@ -202,7 +200,7 @@ class Config: raise TypeError(f"Bad byte size {value!r}") @staticmethod - def parse_duration(value: Union[str, int]) -> int: + def parse_duration(value: str | int) -> int: """Convert a duration as a string or integer to a number of milliseconds. If an integer is provided it is treated as milliseconds and is unchanged. @@ -270,7 +268,7 @@ class Config: return path_exists(file_path) @classmethod - def check_file(cls, file_path: Optional[str], config_name: str) -> str: + def check_file(cls, file_path: str | None, config_name: str) -> str: if file_path is None: raise ConfigError("Missing config for %s." % (config_name,)) try: @@ -318,7 +316,7 @@ class Config: def read_templates( self, filenames: list[str], - custom_template_directories: Optional[Iterable[str]] = None, + custom_template_directories: Iterable[str] | None = None, ) -> list[jinja2.Template]: """Load a list of template files from disk using the given variables. @@ -465,11 +463,11 @@ class RootConfig: data_dir_path: str, server_name: str, generate_secrets: bool = False, - report_stats: Optional[bool] = None, + report_stats: bool | None = None, open_private_ports: bool = False, - listeners: Optional[list[dict]] = None, - tls_certificate_path: Optional[str] = None, - tls_private_key_path: Optional[str] = None, + listeners: list[dict] | None = None, + tls_certificate_path: str | None = None, + tls_private_key_path: str | None = None, ) -> str: """ Build a default configuration file @@ -655,7 +653,7 @@ class RootConfig: @classmethod def load_or_generate_config( cls: type[TRootConfig], description: str, argv_options: list[str] - ) -> Optional[TRootConfig]: + ) -> TRootConfig | None: """Parse the commandline and config files Supports generation of config files, so is used for the main homeserver app. @@ -898,7 +896,7 @@ class RootConfig: :returns: the previous config object, which no longer has a reference to this RootConfig. """ - existing_config: Optional[Config] = getattr(self, section_name, None) + existing_config: Config | None = getattr(self, section_name, None) if existing_config is None: raise ValueError(f"Unknown config section '{section_name}'") logger.info("Reloading config section '%s'", section_name) diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 1a9cb7db47..fe9b3333c4 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -6,9 +6,7 @@ from typing import ( Iterator, Literal, MutableMapping, - Optional, TypeVar, - Union, overload, ) @@ -64,7 +62,7 @@ from synapse.config import ( # noqa: F401 from synapse.types import StrSequence class ConfigError(Exception): - def __init__(self, msg: str, path: Optional[StrSequence] = None): + def __init__(self, msg: str, path: StrSequence | None = None): self.msg = msg self.path = path @@ -146,16 +144,16 @@ class RootConfig: data_dir_path: str, server_name: str, generate_secrets: bool = ..., - report_stats: Optional[bool] = ..., + report_stats: bool | None = ..., open_private_ports: bool = ..., - listeners: Optional[Any] = ..., - tls_certificate_path: Optional[str] = ..., - tls_private_key_path: Optional[str] = ..., + listeners: Any | None = ..., + tls_certificate_path: str | None = ..., + tls_private_key_path: str | None = ..., ) -> str: ... @classmethod def load_or_generate_config( cls: type[TRootConfig], description: str, argv_options: list[str] - ) -> Optional[TRootConfig]: ... + ) -> TRootConfig | None: ... @classmethod def load_config( cls: type[TRootConfig], description: str, argv_options: list[str] @@ -183,11 +181,11 @@ class Config: default_template_dir: str def __init__(self, root_config: RootConfig = ...) -> None: ... @staticmethod - def parse_size(value: Union[str, int]) -> int: ... + def parse_size(value: str | int) -> int: ... @staticmethod - def parse_duration(value: Union[str, int]) -> int: ... + def parse_duration(value: str | int) -> int: ... @staticmethod - def abspath(file_path: Optional[str]) -> str: ... + def abspath(file_path: str | None) -> str: ... @classmethod def path_exists(cls, file_path: str) -> bool: ... @classmethod @@ -200,7 +198,7 @@ class Config: def read_templates( self, filenames: list[str], - custom_template_directories: Optional[Iterable[str]] = None, + custom_template_directories: Iterable[str] | None = None, ) -> list[jinja2.Template]: ... def read_config_files(config_files: Iterable[str]) -> dict[str, Any]: ... diff --git a/synapse/config/api.py b/synapse/config/api.py index e32e03e55e..03b92249a9 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -20,7 +20,7 @@ # import logging -from typing import Any, Iterable, Optional +from typing import Any, Iterable from synapse.api.constants import EventTypes from synapse.config._base import Config, ConfigError @@ -46,7 +46,7 @@ class ApiConfig(Config): def _get_prejoin_state_entries( self, config: JsonDict - ) -> Iterable[tuple[str, Optional[str]]]: + ) -> Iterable[tuple[str, str | None]]: """Get the event types and state keys to include in the prejoin state.""" room_prejoin_state_config = config.get("room_prejoin_state") or {} diff --git a/synapse/config/cache.py b/synapse/config/cache.py index e51efc3dbd..c9ce826e1a 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -23,7 +23,7 @@ import logging import os import re import threading -from typing import Any, Callable, Mapping, Optional +from typing import Any, Callable, Mapping import attr @@ -53,7 +53,7 @@ class CacheProperties: default_factor_size: float = float( os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) ) - resize_all_caches_func: Optional[Callable[[], None]] = None + resize_all_caches_func: Callable[[], None] | None = None properties = CacheProperties() @@ -107,7 +107,7 @@ class CacheConfig(Config): cache_factors: dict[str, float] global_factor: float track_memory_usage: bool - expiry_time_msec: Optional[int] + expiry_time_msec: int | None sync_response_cache_duration: int @staticmethod diff --git a/synapse/config/cas.py b/synapse/config/cas.py index e6e869bb16..dc5be7ccf1 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -20,7 +20,7 @@ # # -from typing import Any, Optional +from typing import Any from synapse.config.sso import SsoAttributeRequirement from synapse.types import JsonDict @@ -49,7 +49,7 @@ class CasConfig(Config): # TODO Update this to a _synapse URL. public_baseurl = self.root.server.public_baseurl - self.cas_service_url: Optional[str] = ( + self.cas_service_url: str | None = ( public_baseurl + "_matrix/client/r0/login/cas/ticket" ) diff --git a/synapse/config/consent.py b/synapse/config/consent.py index 7dc80d4cf8..35484ee033 100644 --- a/synapse/config/consent.py +++ b/synapse/config/consent.py @@ -19,7 +19,7 @@ # from os import path -from typing import Any, Optional +from typing import Any from synapse.config import ConfigError from synapse.types import JsonDict @@ -33,11 +33,11 @@ class ConsentConfig(Config): def __init__(self, *args: Any): super().__init__(*args) - self.user_consent_version: Optional[str] = None - self.user_consent_template_dir: Optional[str] = None - self.user_consent_server_notice_content: Optional[JsonDict] = None + self.user_consent_version: str | None = None + self.user_consent_template_dir: str | None = None + self.user_consent_server_notice_content: JsonDict | None = None self.user_consent_server_notice_to_guests = False - self.block_events_without_consent_error: Optional[str] = None + self.block_events_without_consent_error: str | None = None self.user_consent_at_registration = False self.user_consent_policy_name = "Privacy Policy" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f82e8572f2..52c3ec0da2 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -59,7 +59,7 @@ class ClientAuthMethod(enum.Enum): PRIVATE_KEY_JWT = "private_key_jwt" -def _parse_jwks(jwks: Optional[JsonDict]) -> Optional["JsonWebKey"]: +def _parse_jwks(jwks: JsonDict | None) -> Optional["JsonWebKey"]: """A helper function to parse a JWK dict into a JsonWebKey.""" if jwks is None: @@ -71,7 +71,7 @@ def _parse_jwks(jwks: Optional[JsonDict]) -> Optional["JsonWebKey"]: def _check_client_secret( - instance: "MSC3861", _attribute: attr.Attribute, _value: Optional[str] + instance: "MSC3861", _attribute: attr.Attribute, _value: str | None ) -> None: if instance._client_secret and instance._client_secret_path: raise ConfigError( @@ -88,7 +88,7 @@ def _check_client_secret( def _check_admin_token( - instance: "MSC3861", _attribute: attr.Attribute, _value: Optional[str] + instance: "MSC3861", _attribute: attr.Attribute, _value: str | None ) -> None: if instance._admin_token and instance._admin_token_path: raise ConfigError( @@ -124,7 +124,7 @@ class MSC3861: issuer: str = attr.ib(default="", validator=attr.validators.instance_of(str)) """The URL of the OIDC Provider.""" - issuer_metadata: Optional[JsonDict] = attr.ib(default=None) + issuer_metadata: JsonDict | None = attr.ib(default=None) """The issuer metadata to use, otherwise discovered from /.well-known/openid-configuration as per MSC2965.""" client_id: str = attr.ib( @@ -138,7 +138,7 @@ class MSC3861: ) """The auth method used when calling the introspection endpoint.""" - _client_secret: Optional[str] = attr.ib( + _client_secret: str | None = attr.ib( default=None, validator=[ attr.validators.optional(attr.validators.instance_of(str)), @@ -150,7 +150,7 @@ class MSC3861: when using any of the client_secret_* client auth methods. """ - _client_secret_path: Optional[str] = attr.ib( + _client_secret_path: str | None = attr.ib( default=None, validator=[ attr.validators.optional(attr.validators.instance_of(str)), @@ -196,19 +196,19 @@ class MSC3861: ("experimental", "msc3861", "client_auth_method"), ) - introspection_endpoint: Optional[str] = attr.ib( + introspection_endpoint: str | None = attr.ib( default=None, validator=attr.validators.optional(attr.validators.instance_of(str)), ) """The URL of the introspection endpoint used to validate access tokens.""" - account_management_url: Optional[str] = attr.ib( + account_management_url: str | None = attr.ib( default=None, validator=attr.validators.optional(attr.validators.instance_of(str)), ) """The URL of the My Account page on the OIDC Provider as per MSC2965.""" - _admin_token: Optional[str] = attr.ib( + _admin_token: str | None = attr.ib( default=None, validator=[ attr.validators.optional(attr.validators.instance_of(str)), @@ -220,7 +220,7 @@ class MSC3861: This is used by the OIDC provider, to make admin calls to Synapse. """ - _admin_token_path: Optional[str] = attr.ib( + _admin_token_path: str | None = attr.ib( default=None, validator=[ attr.validators.optional(attr.validators.instance_of(str)), @@ -232,7 +232,7 @@ class MSC3861: external file. """ - def client_secret(self) -> Optional[str]: + def client_secret(self) -> str | None: """Returns the secret given via `client_secret` or `client_secret_path`.""" if self._client_secret_path: return read_secret_from_file_once( @@ -241,7 +241,7 @@ class MSC3861: ) return self._client_secret - def admin_token(self) -> Optional[str]: + def admin_token(self) -> str | None: """Returns the admin token given via `admin_token` or `admin_token_path`.""" if self._admin_token_path: return read_secret_from_file_once( @@ -526,7 +526,7 @@ class ExperimentalConfig(Config): # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code self.msc4108_enabled = experimental.get("msc4108_enabled", False) - self.msc4108_delegation_endpoint: Optional[str] = experimental.get( + self.msc4108_delegation_endpoint: str | None = experimental.get( "msc4108_delegation_endpoint", None ) diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 31f46e420d..ad0bd56a80 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Any, Optional +from typing import Any from synapse.config._base import Config from synapse.config._util import validate_config @@ -32,7 +32,7 @@ class FederationConfig(Config): federation_config = config.setdefault("federation", {}) # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist: Optional[dict] = None + self.federation_domain_whitelist: dict | None = None federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: diff --git a/synapse/config/key.py b/synapse/config/key.py index 3e832b4946..bfeeac5e30 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -23,7 +23,7 @@ import hashlib import logging import os -from typing import TYPE_CHECKING, Any, Iterator, Optional +from typing import TYPE_CHECKING, Any, Iterator import attr import jsonschema @@ -110,7 +110,7 @@ class TrustedKeyServer: server_name: str # map from key id to key object, or None to disable signature verification. - verify_keys: Optional[dict[str, VerifyKey]] = None + verify_keys: dict[str, VerifyKey] | None = None class KeyConfig(Config): @@ -219,7 +219,7 @@ class KeyConfig(Config): if form_secret_path: if form_secret: raise ConfigError(CONFLICTING_FORM_SECRET_OPTS_ERROR) - self.form_secret: Optional[str] = read_file( + self.form_secret: str | None = read_file( form_secret_path, ("form_secret_path",) ).strip() else: @@ -279,7 +279,7 @@ class KeyConfig(Config): raise ConfigError("Error reading %s: %s" % (name, str(e))) def read_old_signing_keys( - self, old_signing_keys: Optional[JsonDict] + self, old_signing_keys: JsonDict | None ) -> dict[str, "VerifyKeyWithExpiry"]: if old_signing_keys is None: return {} @@ -408,7 +408,7 @@ def _parse_key_servers( server_name = server["server_name"] result = TrustedKeyServer(server_name=server_name) - verify_keys: Optional[dict[str, str]] = server.get("verify_keys") + verify_keys: dict[str, str] | None = server.get("verify_keys") if verify_keys is not None: result.verify_keys = {} for key_id, key_base64 in verify_keys.items(): diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 1f5c6da3ae..4af73627be 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -26,7 +26,7 @@ import os import sys import threading from string import Template -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import yaml from zope.interface import implementer @@ -280,7 +280,7 @@ def one_time_logging_setup(*, logBeginner: LogBeginner = globalLogBeginner) -> N def _setup_stdlib_logging( - config: "HomeServerConfig", log_config_path: Optional[str] + config: "HomeServerConfig", log_config_path: str | None ) -> None: """ Set up Python standard library logging. @@ -327,7 +327,7 @@ def _load_logging_config(log_config_path: str) -> None: reset_logging_config() -def _reload_logging_config(log_config_path: Optional[str]) -> None: +def _reload_logging_config(log_config_path: str | None) -> None: """ Reload the log configuration from the file and apply it. """ diff --git a/synapse/config/mas.py b/synapse/config/mas.py index 53cf500e95..dd982589a8 100644 --- a/synapse/config/mas.py +++ b/synapse/config/mas.py @@ -13,7 +13,7 @@ # # -from typing import Any, Optional +from typing import Any from pydantic import ( AnyHttpUrl, @@ -36,8 +36,8 @@ from ._base import Config, ConfigError, RootConfig class MasConfigModel(ParseModel): enabled: StrictBool = False endpoint: AnyHttpUrl = AnyHttpUrl("http://localhost:8080") - secret: Optional[StrictStr] = Field(default=None) - secret_path: Optional[FilePath] = Field(default=None) + secret: StrictStr | None = Field(default=None) + secret_path: FilePath | None = Field(default=None) @model_validator(mode="after") def verify_secret(self) -> Self: diff --git a/synapse/config/matrixrtc.py b/synapse/config/matrixrtc.py index 74fd7cad81..84c245e286 100644 --- a/synapse/config/matrixrtc.py +++ b/synapse/config/matrixrtc.py @@ -15,7 +15,7 @@ # # -from typing import Any, Optional +from typing import Any from pydantic import Field, StrictStr, ValidationError, model_validator from typing_extensions import Self @@ -29,7 +29,7 @@ from ._base import Config, ConfigError class TransportConfigModel(ParseModel): type: StrictStr - livekit_service_url: Optional[StrictStr] = Field(default=None) + livekit_service_url: StrictStr | None = Field(default=None) """An optional livekit service URL. Only required if type is "livekit".""" @model_validator(mode="after") diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 8a4ded62ef..83dbee53b6 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -20,7 +20,7 @@ # # -from typing import Any, Optional +from typing import Any import attr @@ -75,7 +75,7 @@ class MetricsConfig(Config): ) def generate_config_section( - self, report_stats: Optional[bool] = None, **kwargs: Any + self, report_stats: bool | None = None, **kwargs: Any ) -> str: if report_stats is not None: res = "report_stats: %s\n" % ("true" if report_stats else "false") diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py index a4a192302c..208f86374b 100644 --- a/synapse/config/oembed.py +++ b/synapse/config/oembed.py @@ -21,7 +21,7 @@ import importlib.resources as importlib_resources import json import re -from typing import Any, Iterable, Optional, Pattern +from typing import Any, Iterable, Pattern from urllib import parse as urlparse import attr @@ -39,7 +39,7 @@ class OEmbedEndpointConfig: # The patterns to match. url_patterns: list[Pattern[str]] # The supported formats. - formats: Optional[list[str]] + formats: list[str] | None class OembedConfig(Config): diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index ada89bb8bc..73fe6891cd 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -21,7 +21,7 @@ # from collections import Counter -from typing import Any, Collection, Iterable, Mapping, Optional +from typing import Any, Collection, Iterable, Mapping import attr @@ -276,7 +276,7 @@ def _parse_oidc_config_dict( ) from e client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key") - client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] = None + client_secret_jwt_key: OidcProviderClientSecretJwtKey | None = None if client_secret_jwt_key_config is not None: keyfile = client_secret_jwt_key_config.get("key_file") if keyfile: @@ -384,10 +384,10 @@ class OidcProviderConfig: idp_name: str # Optional MXC URI for icon for this IdP. - idp_icon: Optional[str] + idp_icon: str | None # Optional brand identifier for this IdP. - idp_brand: Optional[str] + idp_brand: str | None # whether the OIDC discovery mechanism is used to discover endpoints discover: bool @@ -401,11 +401,11 @@ class OidcProviderConfig: # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate # a secret. - client_secret: Optional[str] + client_secret: str | None # key to use to construct a JWT to use as a client secret. May be `None` if # `client_secret` is set. - client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] + client_secret_jwt_key: OidcProviderClientSecretJwtKey | None # auth method to use when exchanging the token. # Valid values are 'client_secret_basic', 'client_secret_post' and @@ -416,7 +416,7 @@ class OidcProviderConfig: # Valid values are 'auto', 'always', and 'never'. pkce_method: str - id_token_signing_alg_values_supported: Optional[list[str]] + id_token_signing_alg_values_supported: list[str] | None """ List of the JWS signing algorithms (`alg` values) that are supported for signing the `id_token`. @@ -448,18 +448,18 @@ class OidcProviderConfig: scopes: Collection[str] # the oauth2 authorization endpoint. Required if discovery is disabled. - authorization_endpoint: Optional[str] + authorization_endpoint: str | None # the oauth2 token endpoint. Required if discovery is disabled. - token_endpoint: Optional[str] + token_endpoint: str | None # the OIDC userinfo endpoint. Required if discovery is disabled and the # "openid" scope is not requested. - userinfo_endpoint: Optional[str] + userinfo_endpoint: str | None # URI where to fetch the JWKS. Required if discovery is disabled and the # "openid" scope is used. - jwks_uri: Optional[str] + jwks_uri: str | None # Whether Synapse should react to backchannel logouts backchannel_logout_enabled: bool @@ -474,7 +474,7 @@ class OidcProviderConfig: # values are: "auto" or "userinfo_endpoint". user_profile_method: str - redirect_uri: Optional[str] + redirect_uri: str | None """ An optional replacement for Synapse's hardcoded `redirect_uri` URL (`/_synapse/client/oidc/callback`). This can be used to send diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index be2f49f87c..78d9d61d3c 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -19,7 +19,7 @@ # # -from typing import Any, Optional, cast +from typing import Any, cast import attr @@ -39,7 +39,7 @@ class RatelimitSettings: cls, config: dict[str, Any], key: str, - defaults: Optional[dict[str, float]] = None, + defaults: dict[str, float] | None = None, ) -> "RatelimitSettings": """Parse config[key] as a new-style rate limiter config. diff --git a/synapse/config/registration.py b/synapse/config/registration.py index c0e7316bc3..7f7a224e02 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -20,7 +20,7 @@ # # import argparse -from typing import Any, Optional +from typing import Any from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError, read_file @@ -181,7 +181,7 @@ class RegistrationConfig(Config): refreshable_access_token_lifetime = self.parse_duration( refreshable_access_token_lifetime ) - self.refreshable_access_token_lifetime: Optional[int] = ( + self.refreshable_access_token_lifetime: int | None = ( refreshable_access_token_lifetime ) @@ -226,7 +226,7 @@ class RegistrationConfig(Config): refresh_token_lifetime = config.get("refresh_token_lifetime") if refresh_token_lifetime is not None: refresh_token_lifetime = self.parse_duration(refresh_token_lifetime) - self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime + self.refresh_token_lifetime: int | None = refresh_token_lifetime if ( self.session_lifetime is not None diff --git a/synapse/config/retention.py b/synapse/config/retention.py index 9d34f1e241..ab80ac214d 100644 --- a/synapse/config/retention.py +++ b/synapse/config/retention.py @@ -20,7 +20,7 @@ # import logging -from typing import Any, Optional +from typing import Any import attr @@ -35,8 +35,8 @@ class RetentionPurgeJob: """Object describing the configuration of the manhole""" interval: int - shortest_max_lifetime: Optional[int] - longest_max_lifetime: Optional[int] + shortest_max_lifetime: int | None + longest_max_lifetime: int | None class RetentionConfig(Config): diff --git a/synapse/config/server.py b/synapse/config/server.py index 662ed24a13..495f289159 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -25,7 +25,7 @@ import logging import os.path import urllib.parse from textwrap import indent -from typing import Any, Iterable, Optional, TypedDict, Union +from typing import Any, Iterable, TypedDict from urllib.request import getproxies_environment import attr @@ -95,9 +95,9 @@ def _6to4(network: IPNetwork) -> IPNetwork: def generate_ip_set( - ip_addresses: Optional[Iterable[str]], - extra_addresses: Optional[Iterable[str]] = None, - config_path: Optional[StrSequence] = None, + ip_addresses: Iterable[str] | None, + extra_addresses: Iterable[str] | None = None, + config_path: StrSequence | None = None, ) -> IPSet: """ Generate an IPSet from a list of IP addresses or CIDRs. @@ -230,8 +230,8 @@ class HttpListenerConfig: x_forwarded: bool = False resources: list[HttpResourceConfig] = attr.Factory(list) additional_resources: dict[str, dict] = attr.Factory(dict) - tag: Optional[str] = None - request_id_header: Optional[str] = None + tag: str | None = None + request_id_header: str | None = None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -244,7 +244,7 @@ class TCPListenerConfig: tls: bool = False # http_options is only populated if type=http - http_options: Optional[HttpListenerConfig] = None + http_options: HttpListenerConfig | None = None def get_site_tag(self) -> str: """Retrieves http_options.tag if it exists, otherwise the port number.""" @@ -269,7 +269,7 @@ class UnixListenerConfig: type: str = attr.ib(validator=attr.validators.in_(KNOWN_LISTENER_TYPES)) # http_options is only populated if type=http - http_options: Optional[HttpListenerConfig] = None + http_options: HttpListenerConfig | None = None def get_site_tag(self) -> str: return "unix" @@ -279,7 +279,7 @@ class UnixListenerConfig: return False -ListenerConfig = Union[TCPListenerConfig, UnixListenerConfig] +ListenerConfig = TCPListenerConfig | UnixListenerConfig @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -288,14 +288,14 @@ class ManholeConfig: username: str = attr.ib(validator=attr.validators.instance_of(str)) password: str = attr.ib(validator=attr.validators.instance_of(str)) - priv_key: Optional[Key] - pub_key: Optional[Key] + priv_key: Key | None + pub_key: Key | None @attr.s(frozen=True) class LimitRemoteRoomsConfig: enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False) - complexity: Union[float, int] = attr.ib( + complexity: float | int = attr.ib( validator=attr.validators.instance_of((float, int)), # noqa default=1.0, ) @@ -313,11 +313,11 @@ class ProxyConfigDictionary(TypedDict): Dictionary of proxy settings suitable for interacting with `urllib.request` API's """ - http: Optional[str] + http: str | None """ Proxy server to use for HTTP requests. """ - https: Optional[str] + https: str | None """ Proxy server to use for HTTPS requests. """ @@ -336,15 +336,15 @@ class ProxyConfig: Synapse configuration for HTTP proxy settings. """ - http_proxy: Optional[str] + http_proxy: str | None """ Proxy server to use for HTTP requests. """ - https_proxy: Optional[str] + https_proxy: str | None """ Proxy server to use for HTTPS requests. """ - no_proxy_hosts: Optional[list[str]] + no_proxy_hosts: list[str] | None """ List of hosts, IP addresses, or IP ranges in CIDR format which should not use the proxy. Synapse will directly connect to these hosts. @@ -607,7 +607,7 @@ class ServerConfig(Config): # before redacting them. redaction_retention_period = config.get("redaction_retention_period", "7d") if redaction_retention_period is not None: - self.redaction_retention_period: Optional[int] = self.parse_duration( + self.redaction_retention_period: int | None = self.parse_duration( redaction_retention_period ) else: @@ -618,7 +618,7 @@ class ServerConfig(Config): "forgotten_room_retention_period", None ) if forgotten_room_retention_period is not None: - self.forgotten_room_retention_period: Optional[int] = self.parse_duration( + self.forgotten_room_retention_period: int | None = self.parse_duration( forgotten_room_retention_period ) else: @@ -627,7 +627,7 @@ class ServerConfig(Config): # How long to keep entries in the `users_ips` table. user_ips_max_age = config.get("user_ips_max_age", "28d") if user_ips_max_age is not None: - self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age) + self.user_ips_max_age: int | None = self.parse_duration(user_ips_max_age) else: self.user_ips_max_age = None @@ -864,11 +864,11 @@ class ServerConfig(Config): ) # Whitelist of domain names that given next_link parameters must have - next_link_domain_whitelist: Optional[list[str]] = config.get( + next_link_domain_whitelist: list[str] | None = config.get( "next_link_domain_whitelist" ) - self.next_link_domain_whitelist: Optional[set[str]] = None + self.next_link_domain_whitelist: set[str] | None = None if next_link_domain_whitelist is not None: if not isinstance(next_link_domain_whitelist, list): raise ConfigError("'next_link_domain_whitelist' must be a list") @@ -880,7 +880,7 @@ class ServerConfig(Config): if not isinstance(templates_config, dict): raise ConfigError("The 'templates' section must be a dictionary") - self.custom_template_directory: Optional[str] = templates_config.get( + self.custom_template_directory: str | None = templates_config.get( "custom_template_directory" ) if self.custom_template_directory is not None and not isinstance( @@ -896,12 +896,12 @@ class ServerConfig(Config): config.get("exclude_rooms_from_sync") or [] ) - delete_stale_devices_after: Optional[str] = ( + delete_stale_devices_after: str | None = ( config.get("delete_stale_devices_after") or None ) if delete_stale_devices_after is not None: - self.delete_stale_devices_after: Optional[int] = self.parse_duration( + self.delete_stale_devices_after: int | None = self.parse_duration( delete_stale_devices_after ) else: @@ -910,7 +910,7 @@ class ServerConfig(Config): # The maximum allowed delay duration for delayed events (MSC4140). max_event_delay_duration = config.get("max_event_delay_duration") if max_event_delay_duration is not None: - self.max_event_delay_ms: Optional[int] = self.parse_duration( + self.max_event_delay_ms: int | None = self.parse_duration( max_event_delay_duration ) if self.max_event_delay_ms <= 0: @@ -927,7 +927,7 @@ class ServerConfig(Config): data_dir_path: str, server_name: str, open_private_ports: bool, - listeners: Optional[list[dict]], + listeners: list[dict] | None, **kwargs: Any, ) -> str: _, bind_port = parse_and_validate_server_name(server_name) @@ -1028,7 +1028,7 @@ class ServerConfig(Config): help="Turn on the twisted telnet manhole service on the given port.", ) - def read_gc_intervals(self, durations: Any) -> Optional[tuple[float, float, float]]: + def read_gc_intervals(self, durations: Any) -> tuple[float, float, float] | None: """Reads the three durations for the GC min interval option, returning seconds.""" if durations is None: return None @@ -1066,8 +1066,8 @@ def is_threepid_reserved( def read_gc_thresholds( - thresholds: Optional[list[Any]], -) -> Optional[tuple[int, int, int]]: + thresholds: list[Any] | None, +) -> tuple[int, int, int] | None: """Reads the three integer thresholds for garbage collection. Ensures that the thresholds are integers if thresholds are supplied. """ diff --git a/synapse/config/server_notices.py b/synapse/config/server_notices.py index 4de2d62b54..d19e2569a1 100644 --- a/synapse/config/server_notices.py +++ b/synapse/config/server_notices.py @@ -18,7 +18,7 @@ # # -from typing import Any, Optional +from typing import Any from synapse.types import JsonDict, UserID @@ -58,12 +58,12 @@ class ServerNoticesConfig(Config): def __init__(self, *args: Any): super().__init__(*args) - self.server_notices_mxid: Optional[str] = None - self.server_notices_mxid_display_name: Optional[str] = None - self.server_notices_mxid_avatar_url: Optional[str] = None - self.server_notices_room_name: Optional[str] = None - self.server_notices_room_avatar_url: Optional[str] = None - self.server_notices_room_topic: Optional[str] = None + self.server_notices_mxid: str | None = None + self.server_notices_mxid_display_name: str | None = None + self.server_notices_mxid_avatar_url: str | None = None + self.server_notices_room_name: str | None = None + self.server_notices_room_avatar_url: str | None = None + self.server_notices_room_topic: str | None = None self.server_notices_auto_join: bool = False def read_config(self, config: JsonDict, **kwargs: Any) -> None: diff --git a/synapse/config/sso.py b/synapse/config/sso.py index facb418510..1d08bef868 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -19,7 +19,7 @@ # # import logging -from typing import Any, Optional +from typing import Any import attr @@ -44,8 +44,8 @@ class SsoAttributeRequirement: attribute: str # If neither `value` nor `one_of` is given, the attribute must simply exist. - value: Optional[str] = None - one_of: Optional[list[str]] = None + value: str | None = None + one_of: list[str] | None = None JSON_SCHEMA = { "type": "object", diff --git a/synapse/config/tls.py b/synapse/config/tls.py index d03a77d9d2..de4d676e08 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -20,7 +20,7 @@ # import logging -from typing import Any, Optional, Pattern +from typing import Any, Pattern from matrix_common.regex import glob_to_regex @@ -135,8 +135,8 @@ class TlsConfig(Config): "use_insecure_ssl_client_just_for_testing_do_not_use" ) - self.tls_certificate: Optional[crypto.X509] = None - self.tls_private_key: Optional[crypto.PKey] = None + self.tls_certificate: crypto.X509 | None = None + self.tls_private_key: crypto.PKey | None = None def read_certificate_from_disk(self) -> None: """ @@ -147,8 +147,8 @@ class TlsConfig(Config): def generate_config_section( self, - tls_certificate_path: Optional[str], - tls_private_key_path: Optional[str], + tls_certificate_path: str | None, + tls_private_key_path: str | None, **kwargs: Any, ) -> str: """If the TLS paths are not specified the default will be certs in the diff --git a/synapse/config/user_types.py b/synapse/config/user_types.py index dd64425d6c..e47713b7f4 100644 --- a/synapse/config/user_types.py +++ b/synapse/config/user_types.py @@ -12,7 +12,7 @@ # . # -from typing import Any, Optional +from typing import Any from synapse.api.constants import UserTypes from synapse.types import JsonDict @@ -26,9 +26,7 @@ class UserTypesConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: user_types: JsonDict = config.get("user_types", {}) - self.default_user_type: Optional[str] = user_types.get( - "default_user_type", None - ) + self.default_user_type: str | None = user_types.get("default_user_type", None) self.extra_user_types: list[str] = user_types.get("extra_user_types", []) all_user_types: list[str] = [] diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 90f8c72412..ec8ab9506b 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -22,7 +22,7 @@ import argparse import logging -from typing import Any, Optional, Union +from typing import Any import attr from pydantic import ( @@ -79,7 +79,7 @@ MAIN_PROCESS_INSTANCE_MAP_NAME = "main" logger = logging.getLogger(__name__) -def _instance_to_list_converter(obj: Union[str, list[str]]) -> list[str]: +def _instance_to_list_converter(obj: str | list[str]) -> list[str]: """Helper for allowing parsing a string or list of strings to a config option expecting a list of strings. """ @@ -119,7 +119,7 @@ class InstanceUnixLocationConfig(ParseModel): return f"{self.path}" -InstanceLocationConfig = Union[InstanceTcpLocationConfig, InstanceUnixLocationConfig] +InstanceLocationConfig = InstanceTcpLocationConfig | InstanceUnixLocationConfig @attr.s @@ -190,7 +190,7 @@ class OutboundFederationRestrictedTo: locations: list of instance locations to connect to proxy via. """ - instances: Optional[list[str]] + instances: list[str] | None locations: list[InstanceLocationConfig] = attr.Factory(list) def __contains__(self, instance: str) -> bool: @@ -246,7 +246,7 @@ class WorkerConfig(Config): if worker_replication_secret_path: if worker_replication_secret: raise ConfigError(CONFLICTING_WORKER_REPLICATION_SECRET_OPTS_ERROR) - self.worker_replication_secret: Optional[str] = read_file( + self.worker_replication_secret: str | None = read_file( worker_replication_secret_path, ("worker_replication_secret_path",) ).strip() else: @@ -341,7 +341,7 @@ class WorkerConfig(Config): % MAIN_PROCESS_INSTANCE_MAP_NAME ) - # type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently + # type-ignore: the expression `A | B` is not a `type[A | B]` currently self.instance_map: dict[str, InstanceLocationConfig] = ( parse_and_validate_mapping( instance_map, diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 24a693fdb1..3abb644df5 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -21,7 +21,7 @@ import abc import logging -from typing import TYPE_CHECKING, Callable, Iterable, Optional +from typing import TYPE_CHECKING, Callable, Iterable import attr from signedjson.key import ( @@ -150,7 +150,7 @@ class Keyring: """ def __init__( - self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None + self, hs: "HomeServer", key_fetchers: "Iterable[KeyFetcher] | None" = None ): self.server_name = hs.hostname diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 5d927a925a..66f50115e3 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -160,7 +160,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: async def check_state_independent_auth_rules( store: _EventSourceStore, event: "EventBase", - batched_auth_events: Optional[Mapping[str, "EventBase"]] = None, + batched_auth_events: Mapping[str, "EventBase"] | None = None, ) -> None: """Check that an event complies with auth rules that are independent of room state @@ -788,7 +788,7 @@ def _check_joined_room( def get_send_level( - etype: str, state_key: Optional[str], power_levels_event: Optional["EventBase"] + etype: str, state_key: str | None, power_levels_event: Optional["EventBase"] ) -> int: """Get the power level required to send an event of a given type @@ -989,7 +989,7 @@ def _check_power_levels( user_level = get_user_power_level(event.user_id, auth_events) # Check other levels: - levels_to_check: list[tuple[str, Optional[str]]] = [ + levels_to_check: list[tuple[str, str | None]] = [ ("users_default", None), ("events_default", None), ("state_default", None), @@ -1027,12 +1027,12 @@ def _check_power_levels( new_loc = new_loc.get(dir, {}) if level_to_check in old_loc: - old_level: Optional[int] = int(old_loc[level_to_check]) + old_level: int | None = int(old_loc[level_to_check]) else: old_level = None if level_to_check in new_loc: - new_level: Optional[int] = int(new_loc[level_to_check]) + new_level: int | None = int(new_loc[level_to_check]) else: new_level = None diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index a353076e0d..5f78603782 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -28,7 +28,6 @@ from typing import ( Generic, Iterable, Literal, - Optional, TypeVar, Union, overload, @@ -90,21 +89,21 @@ class DictProperty(Generic[T]): def __get__( self, instance: Literal[None], - owner: Optional[type[_DictPropertyInstance]] = None, + owner: type[_DictPropertyInstance] | None = None, ) -> "DictProperty": ... @overload def __get__( self, instance: _DictPropertyInstance, - owner: Optional[type[_DictPropertyInstance]] = None, + owner: type[_DictPropertyInstance] | None = None, ) -> T: ... def __get__( self, - instance: Optional[_DictPropertyInstance], - owner: Optional[type[_DictPropertyInstance]] = None, - ) -> Union[T, "DictProperty"]: + instance: _DictPropertyInstance | None, + owner: type[_DictPropertyInstance] | None = None, + ) -> T | "DictProperty": # if the property is accessed as a class property rather than an instance # property, return the property itself rather than the value if instance is None: @@ -156,21 +155,21 @@ class DefaultDictProperty(DictProperty, Generic[T]): def __get__( self, instance: Literal[None], - owner: Optional[type[_DictPropertyInstance]] = None, + owner: type[_DictPropertyInstance] | None = None, ) -> "DefaultDictProperty": ... @overload def __get__( self, instance: _DictPropertyInstance, - owner: Optional[type[_DictPropertyInstance]] = None, + owner: type[_DictPropertyInstance] | None = None, ) -> T: ... def __get__( self, - instance: Optional[_DictPropertyInstance], - owner: Optional[type[_DictPropertyInstance]] = None, - ) -> Union[T, "DefaultDictProperty"]: + instance: _DictPropertyInstance | None, + owner: type[_DictPropertyInstance] | None = None, + ) -> T | "DefaultDictProperty": if instance is None: return self assert isinstance(instance, EventBase) @@ -191,7 +190,7 @@ class EventBase(metaclass=abc.ABCMeta): signatures: dict[str, dict[str, str]], unsigned: JsonDict, internal_metadata_dict: JsonDict, - rejected_reason: Optional[str], + rejected_reason: str | None, ): assert room_version.event_format == self.format_version @@ -209,7 +208,7 @@ class EventBase(metaclass=abc.ABCMeta): hashes: DictProperty[dict[str, str]] = DictProperty("hashes") origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts") sender: DictProperty[str] = DictProperty("sender") - # TODO state_key should be Optional[str]. This is generally asserted in Synapse + # TODO state_key should be str | None. This is generally asserted in Synapse # by calling is_state() first (which ensures it is not None), but it is hard (not possible?) # to properly annotate that calling is_state() asserts that state_key exists # and is non-None. It would be better to replace such direct references with @@ -231,7 +230,7 @@ class EventBase(metaclass=abc.ABCMeta): return self.content["membership"] @property - def redacts(self) -> Optional[str]: + def redacts(self) -> str | None: """MSC2176 moved the redacts field into the content.""" if self.room_version.updated_redaction_rules: return self.content.get("redacts") @@ -240,7 +239,7 @@ class EventBase(metaclass=abc.ABCMeta): def is_state(self) -> bool: return self.get_state_key() is not None - def get_state_key(self) -> Optional[str]: + def get_state_key(self) -> str | None: """Get the state key of this event, or None if it's not a state event""" return self._dict.get("state_key") @@ -250,13 +249,13 @@ class EventBase(metaclass=abc.ABCMeta): return d - def get(self, key: str, default: Optional[Any] = None) -> Any: + def get(self, key: str, default: Any | None = None) -> Any: return self._dict.get(key, default) def get_internal_metadata_dict(self) -> JsonDict: return self.internal_metadata.get_dict() - def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict: + def get_pdu_json(self, time_now: int | None = None) -> JsonDict: pdu_json = self.get_dict() if time_now is not None and "age_ts" in pdu_json["unsigned"]: @@ -283,13 +282,13 @@ class EventBase(metaclass=abc.ABCMeta): return template_json - def __getitem__(self, field: str) -> Optional[Any]: + def __getitem__(self, field: str) -> Any | None: return self._dict[field] def __contains__(self, field: str) -> bool: return field in self._dict - def items(self) -> list[tuple[str, Optional[Any]]]: + def items(self) -> list[tuple[str, Any | None]]: return list(self._dict.items()) def keys(self) -> Iterable[str]: @@ -348,8 +347,8 @@ class FrozenEvent(EventBase): self, event_dict: JsonDict, room_version: RoomVersion, - internal_metadata_dict: Optional[JsonDict] = None, - rejected_reason: Optional[str] = None, + internal_metadata_dict: JsonDict | None = None, + rejected_reason: str | None = None, ): internal_metadata_dict = internal_metadata_dict or {} @@ -400,8 +399,8 @@ class FrozenEventV2(EventBase): self, event_dict: JsonDict, room_version: RoomVersion, - internal_metadata_dict: Optional[JsonDict] = None, - rejected_reason: Optional[str] = None, + internal_metadata_dict: JsonDict | None = None, + rejected_reason: str | None = None, ): internal_metadata_dict = internal_metadata_dict or {} @@ -427,7 +426,7 @@ class FrozenEventV2(EventBase): else: frozen_dict = event_dict - self._event_id: Optional[str] = None + self._event_id: str | None = None super().__init__( frozen_dict, @@ -502,8 +501,8 @@ class FrozenEventV4(FrozenEventV3): self, event_dict: JsonDict, room_version: RoomVersion, - internal_metadata_dict: Optional[JsonDict] = None, - rejected_reason: Optional[str] = None, + internal_metadata_dict: JsonDict | None = None, + rejected_reason: str | None = None, ): super().__init__( event_dict=event_dict, @@ -511,7 +510,7 @@ class FrozenEventV4(FrozenEventV3): internal_metadata_dict=internal_metadata_dict, rejected_reason=rejected_reason, ) - self._room_id: Optional[str] = None + self._room_id: str | None = None @property def room_id(self) -> str: @@ -554,7 +553,7 @@ class FrozenEventV4(FrozenEventV3): def _event_type_from_format_version( format_version: int, -) -> type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]: +) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3]: """Returns the python type to use to construct an Event object for the given event format version. @@ -580,8 +579,8 @@ def _event_type_from_format_version( def make_event_from_dict( event_dict: JsonDict, room_version: RoomVersion = RoomVersions.V1, - internal_metadata_dict: Optional[JsonDict] = None, - rejected_reason: Optional[str] = None, + internal_metadata_dict: JsonDict | None = None, + rejected_reason: str | None = None, ) -> EventBase: """Construct an EventBase from the given event dict""" event_type = _event_type_from_format_version(room_version.event_format) @@ -598,10 +597,10 @@ class _EventRelation: rel_type: str # The aggregation key. Will be None if the rel_type is not m.annotation or is # not a string. - aggregation_key: Optional[str] + aggregation_key: str | None -def relation_from_event(event: EventBase) -> Optional[_EventRelation]: +def relation_from_event(event: EventBase) -> _EventRelation | None: """ Attempt to parse relation information an event. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index a57303c999..6a2812109d 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -19,7 +19,7 @@ # # import logging -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import attr from signedjson.types import SigningKey @@ -83,7 +83,7 @@ class EventBuilder: room_version: RoomVersion # MSC4291 makes the room ID == the create event ID. This means the create event has no room_id. - room_id: Optional[str] + room_id: str | None type: str sender: str @@ -92,9 +92,9 @@ class EventBuilder: # These only exist on a subset of events, so they raise AttributeError if # someone tries to get them when they don't exist. - _state_key: Optional[str] = None - _redacts: Optional[str] = None - _origin_server_ts: Optional[int] = None + _state_key: str | None = None + _redacts: str | None = None + _origin_server_ts: int | None = None internal_metadata: EventInternalMetadata = attr.Factory( lambda: EventInternalMetadata({}) @@ -126,8 +126,8 @@ class EventBuilder: async def build( self, prev_event_ids: list[str], - auth_event_ids: Optional[list[str]], - depth: Optional[int] = None, + auth_event_ids: list[str] | None, + depth: int | None = None, ) -> EventBase: """Transform into a fully signed and hashed event @@ -205,8 +205,8 @@ class EventBuilder: format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. - prev_events: Union[StrCollection, list[tuple[str, dict[str, str]]]] - auth_events: Union[list[str], list[tuple[str, dict[str, str]]]] + prev_events: StrCollection | list[tuple[str, dict[str, str]]] + auth_events: list[str] | list[tuple[str, dict[str, str]]] if format_version == EventFormatVersions.ROOM_V1_V2: auth_events = await self._store.add_event_hashes(auth_event_ids) prev_events = await self._store.add_event_hashes(prev_event_ids) @@ -327,7 +327,7 @@ def create_local_event_from_event_dict( signing_key: SigningKey, room_version: RoomVersion, event_dict: JsonDict, - internal_metadata_dict: Optional[JsonDict] = None, + internal_metadata_dict: JsonDict | None = None, ) -> EventBase: """Takes a fully formed event dict, ensuring that fields like `origin_server_ts` have correct values for a locally produced event, diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py index 39dd7ee2b3..d71d3f8feb 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py @@ -25,9 +25,7 @@ from typing import ( Awaitable, Callable, Iterable, - Optional, TypeVar, - Union, ) from typing_extensions import ParamSpec @@ -44,7 +42,7 @@ GET_USERS_FOR_STATES_CALLBACK = Callable[ [Iterable[UserPresenceState]], Awaitable[dict[str, set[UserPresenceState]]] ] # This must either return a set of strings or the constant PresenceRouter.ALL_USERS. -GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[set[str], str]]] +GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[set[str] | str]] logger = logging.getLogger(__name__) @@ -77,8 +75,8 @@ def load_legacy_presence_router(hs: "HomeServer") -> None: # All methods that the module provides should be async, but this wasn't enforced # in the old module system, so we wrap them if needed def async_wrapper( - f: Optional[Callable[P, R]], - ) -> Optional[Callable[P, Awaitable[R]]]: + f: Callable[P, R] | None, + ) -> Callable[P, Awaitable[R]] | None: # f might be None if the callback isn't implemented by the module. In this # case we don't want to register a callback at all so we return None. if f is None: @@ -95,7 +93,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None: return run # Register the hooks through the module API. - hooks: dict[str, Optional[Callable[..., Any]]] = { + hooks: dict[str, Callable[..., Any] | None] = { hook: async_wrapper(getattr(presence_router, hook, None)) for hook in presence_router_methods } @@ -118,8 +116,8 @@ class PresenceRouter: def register_presence_router_callbacks( self, - get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, - get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, + get_users_for_states: GET_USERS_FOR_STATES_CALLBACK | None = None, + get_interested_users: GET_INTERESTED_USERS_CALLBACK | None = None, ) -> None: # PresenceRouter modules are required to implement both of these methods # or neither of them as they are assumed to act in a complementary manner @@ -191,7 +189,7 @@ class PresenceRouter: return users_for_states - async def get_interested_users(self, user_id: str) -> Union[set[str], str]: + async def get_interested_users(self, user_id: str) -> set[str] | str: """ Retrieve a list of users that `user_id` is interested in receiving the presence of. This will be in addition to those they share a room with. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 764d31ee66..d7a987d52f 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -51,7 +51,7 @@ class UnpersistedEventContextBase(ABC): def __init__(self, storage_controller: "StorageControllers"): self._storage: "StorageControllers" = storage_controller - self.app_service: Optional[ApplicationService] = None + self.app_service: ApplicationService | None = None @abstractmethod async def persist( @@ -134,20 +134,20 @@ class EventContext(UnpersistedEventContextBase): _storage: "StorageControllers" state_group_deltas: dict[tuple[int, int], StateMap[str]] - rejected: Optional[str] = None - _state_group: Optional[int] = None - state_group_before_event: Optional[int] = None - _state_delta_due_to_event: Optional[StateMap[str]] = None - app_service: Optional[ApplicationService] = None + rejected: str | None = None + _state_group: int | None = None + state_group_before_event: int | None = None + _state_delta_due_to_event: StateMap[str] | None = None + app_service: ApplicationService | None = None partial_state: bool = False @staticmethod def with_state( storage: "StorageControllers", - state_group: Optional[int], - state_group_before_event: Optional[int], - state_delta_due_to_event: Optional[StateMap[str]], + state_group: int | None, + state_group_before_event: int | None, + state_delta_due_to_event: StateMap[str] | None, partial_state: bool, state_group_deltas: dict[tuple[int, int], StateMap[str]], ) -> "EventContext": @@ -227,7 +227,7 @@ class EventContext(UnpersistedEventContextBase): return context @property - def state_group(self) -> Optional[int]: + def state_group(self) -> int | None: """The ID of the state group for this event. Note that state events are persisted with a state group which includes the new @@ -354,13 +354,13 @@ class UnpersistedEventContext(UnpersistedEventContextBase): """ _storage: "StorageControllers" - state_group_before_event: Optional[int] - state_group_after_event: Optional[int] - state_delta_due_to_event: Optional[StateMap[str]] - prev_group_for_state_group_before_event: Optional[int] - delta_ids_to_state_group_before_event: Optional[StateMap[str]] + state_group_before_event: int | None + state_group_after_event: int | None + state_delta_due_to_event: StateMap[str] | None + prev_group_for_state_group_before_event: int | None + delta_ids_to_state_group_before_event: StateMap[str] | None partial_state: bool - state_map_before_event: Optional[StateMap[str]] = None + state_map_before_event: StateMap[str] | None = None @classmethod async def batch_persist_unpersisted_contexts( @@ -511,7 +511,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase): def _encode_state_group_delta( state_group_delta: dict[tuple[int, int], StateMap[str]], -) -> list[tuple[int, int, Optional[list[tuple[str, str, str]]]]]: +) -> list[tuple[int, int, list[tuple[str, str, str]] | None]]: if not state_group_delta: return [] @@ -538,8 +538,8 @@ def _decode_state_group_delta( def _encode_state_dict( - state_dict: Optional[StateMap[str]], -) -> Optional[list[tuple[str, str, str]]]: + state_dict: StateMap[str] | None, +) -> list[tuple[str, str, str]] | None: """Since dicts of (type, state_key) -> event_id cannot be serialized in JSON we need to convert them to a form that can. """ @@ -550,8 +550,8 @@ def _encode_state_dict( def _decode_state_dict( - input: Optional[list[tuple[str, str, str]]], -) -> Optional[StateMap[str]]: + input: list[tuple[str, str, str]] | None, +) -> StateMap[str] | None: """Decodes a state dict encoded using `_encode_state_dict` above""" if input is None: return None diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 9fa251abd8..b79a68f589 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -30,8 +30,6 @@ from typing import ( Mapping, Match, MutableMapping, - Optional, - Union, ) import attr @@ -415,9 +413,9 @@ class SerializeEventConfig: event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 # The entity that requested the event. This is used to determine whether to include # the transaction_id in the unsigned section of the event. - requester: Optional[Requester] = None + requester: Requester | None = None # List of event fields to include. If empty, all fields will be returned. - only_event_fields: Optional[list[str]] = None + only_event_fields: list[str] | None = None # Some events can have stripped room state stored in the `unsigned` field. # This is required for invite and knock functionality. If this option is # False, that state will be removed from the event before it is returned. @@ -439,7 +437,7 @@ def make_config_for_admin(existing: SerializeEventConfig) -> SerializeEventConfi def serialize_event( - e: Union[JsonDict, EventBase], + e: JsonDict | EventBase, time_now_ms: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, @@ -480,7 +478,7 @@ def serialize_event( # If we have a txn_id saved in the internal_metadata, we should include it in the # unsigned section of the event if it was sent by the same session as the one # requesting the event. - txn_id: Optional[str] = getattr(e.internal_metadata, "txn_id", None) + txn_id: str | None = getattr(e.internal_metadata, "txn_id", None) if ( txn_id is not None and config.requester is not None @@ -490,7 +488,7 @@ def serialize_event( # this includes old events as well as those created by appservice, guests, # or with tokens minted with the admin API. For those events, fallback # to using the access token instead. - event_device_id: Optional[str] = getattr(e.internal_metadata, "device_id", None) + event_device_id: str | None = getattr(e.internal_metadata, "device_id", None) if event_device_id is not None: if event_device_id == config.requester.device_id: d["unsigned"]["transaction_id"] = txn_id @@ -504,9 +502,7 @@ def serialize_event( # # For guests and appservice users, we can't check the access token ID # so assume it is the same session. - event_token_id: Optional[int] = getattr( - e.internal_metadata, "token_id", None - ) + event_token_id: int | None = getattr(e.internal_metadata, "token_id", None) if ( ( event_token_id is not None @@ -577,11 +573,11 @@ class EventClientSerializer: async def serialize_event( self, - event: Union[JsonDict, EventBase], + event: JsonDict | EventBase, time_now: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, - bundle_aggregations: Optional[dict[str, "BundledAggregations"]] = None, + bundle_aggregations: dict[str, "BundledAggregations"] | None = None, ) -> JsonDict: """Serializes a single event. @@ -712,11 +708,11 @@ class EventClientSerializer: @trace async def serialize_events( self, - events: Collection[Union[JsonDict, EventBase]], + events: Collection[JsonDict | EventBase], time_now: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, - bundle_aggregations: Optional[dict[str, "BundledAggregations"]] = None, + bundle_aggregations: dict[str, "BundledAggregations"] | None = None, ) -> list[JsonDict]: """Serializes multiple events. @@ -755,13 +751,13 @@ class EventClientSerializer: self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback) -_PowerLevel = Union[str, int] -PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] +_PowerLevel = str | int +PowerLevelsContent = Mapping[str, _PowerLevel | Mapping[str, _PowerLevel]] def copy_and_fixup_power_levels_contents( old_power_levels: PowerLevelsContent, -) -> dict[str, Union[int, dict[str, int]]]: +) -> dict[str, int | dict[str, int]]: """Copy the content of a power_levels event, unfreezing immutabledicts along the way. We accept as input power level values which are strings, provided they represent an @@ -777,7 +773,7 @@ def copy_and_fixup_power_levels_contents( if not isinstance(old_power_levels, collections.abc.Mapping): raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) - power_levels: dict[str, Union[int, dict[str, int]]] = {} + power_levels: dict[str, int | dict[str, int]] = {} for k, v in old_power_levels.items(): if isinstance(v, collections.abc.Mapping): @@ -901,7 +897,7 @@ def strip_event(event: EventBase) -> JsonDict: } -def parse_stripped_state_event(raw_stripped_event: Any) -> Optional[StrippedStateEvent]: +def parse_stripped_state_event(raw_stripped_event: Any) -> StrippedStateEvent | None: """ Given a raw value from an event's `unsigned` field, attempt to parse it into a `StrippedStateEvent`. diff --git a/synapse/events/validator.py b/synapse/events/validator.py index c2cecd0fcb..b27f8a942a 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -19,7 +19,7 @@ # # import collections.abc -from typing import Union, cast +from typing import cast import jsonschema from pydantic import Field, StrictBool, StrictStr @@ -177,7 +177,7 @@ class EventValidator: errcode=Codes.BAD_JSON, ) - def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None: + def validate_builder(self, event: EventBase | EventBuilder) -> None: """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the fields an event would have @@ -249,7 +249,7 @@ class EventValidator: if not isinstance(d[s], str): raise SynapseError(400, "'%s' not a string type" % (s,)) - def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None: + def _ensure_state_event(self, event: EventBase | EventBuilder) -> None: if not event.is_state(): raise SynapseError(400, "'%s' must be state events" % (event.type,)) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 13e445456a..04ba5b86db 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -20,7 +20,7 @@ # # import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Sequence +from typing import TYPE_CHECKING, Awaitable, Callable, Sequence from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -67,7 +67,7 @@ class FederationBase: # We need to define this lazily otherwise we get a cyclic dependency. # self._policy_handler = hs.get_room_policy_handler() - self._policy_handler: Optional[RoomPolicyHandler] = None + self._policy_handler: RoomPolicyHandler | None = None def _lazily_get_policy_handler(self) -> RoomPolicyHandler: """Lazily get the room policy handler. @@ -88,9 +88,8 @@ class FederationBase: self, room_version: RoomVersion, pdu: EventBase, - record_failure_callback: Optional[ - Callable[[EventBase, str], Awaitable[None]] - ] = None, + record_failure_callback: Callable[[EventBase, str], Awaitable[None]] + | None = None, ) -> EventBase: """Checks that event is correctly signed by the sending server. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index cb2fa59f54..4110a90ed6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -37,7 +37,6 @@ from typing import ( Optional, Sequence, TypeVar, - Union, ) import attr @@ -263,7 +262,7 @@ class FederationClient(FederationBase): user: UserID, destination: str, query: dict[str, dict[str, dict[str, int]]], - timeout: Optional[int], + timeout: int | None, ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. @@ -334,7 +333,7 @@ class FederationClient(FederationBase): @tag_args async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] - ) -> Optional[list[EventBase]]: + ) -> list[EventBase] | None: """Requests some more historic PDUs for the given room from the given destination server. @@ -381,8 +380,8 @@ class FederationClient(FederationBase): destination: str, event_id: str, room_version: RoomVersion, - timeout: Optional[int] = None, - ) -> Optional[EventBase]: + timeout: int | None = None, + ) -> EventBase | None: """Requests the PDU with given origin and ID from the remote home server. Does not have any caching or rate limiting! @@ -441,7 +440,7 @@ class FederationClient(FederationBase): @trace @tag_args async def get_pdu_policy_recommendation( - self, destination: str, pdu: EventBase, timeout: Optional[int] = None + self, destination: str, pdu: EventBase, timeout: int | None = None ) -> str: """Requests that the destination server (typically a policy server) check the event and return its recommendation on how to handle the @@ -497,8 +496,8 @@ class FederationClient(FederationBase): @trace @tag_args async def ask_policy_server_to_sign_event( - self, destination: str, pdu: EventBase, timeout: Optional[int] = None - ) -> Optional[JsonDict]: + self, destination: str, pdu: EventBase, timeout: int | None = None + ) -> JsonDict | None: """Requests that the destination server (typically a policy server) sign the event as not spam. @@ -538,8 +537,8 @@ class FederationClient(FederationBase): destinations: Collection[str], event_id: str, room_version: RoomVersion, - timeout: Optional[int] = None, - ) -> Optional[PulledPduInfo]: + timeout: int | None = None, + ) -> PulledPduInfo | None: """Requests the PDU with given origin and ID from the remote home servers. @@ -832,10 +831,9 @@ class FederationClient(FederationBase): pdu: EventBase, origin: str, room_version: RoomVersion, - record_failure_callback: Optional[ - Callable[[EventBase, str], Awaitable[None]] - ] = None, - ) -> Optional[EventBase]: + record_failure_callback: Callable[[EventBase, str], Awaitable[None]] + | None = None, + ) -> EventBase | None: """Takes a PDU and checks its signatures and hashes. If the PDU fails its signature check then we check if we have it in the @@ -931,7 +929,7 @@ class FederationClient(FederationBase): description: str, destinations: Iterable[str], callback: Callable[[str], Awaitable[T]], - failover_errcodes: Optional[Container[str]] = None, + failover_errcodes: Container[str] | None = None, failover_on_unknown_endpoint: bool = False, ) -> T: """Try an operation on a series of servers, until it succeeds @@ -1046,7 +1044,7 @@ class FederationClient(FederationBase): user_id: str, membership: str, content: dict, - params: Optional[Mapping[str, Union[str, Iterable[str]]]], + params: Mapping[str, str | Iterable[str]] | None, ) -> tuple[str, EventBase, RoomVersion]: """ Creates an m.room.member event, with context, without participating in the room. @@ -1563,11 +1561,11 @@ class FederationClient(FederationBase): async def get_public_rooms( self, remote_server: str, - limit: Optional[int] = None, - since_token: Optional[str] = None, - search_filter: Optional[dict] = None, + limit: int | None = None, + since_token: str | None = None, + search_filter: dict | None = None, include_all_networks: bool = False, - third_party_instance_id: Optional[str] = None, + third_party_instance_id: str | None = None, ) -> JsonDict: """Get the list of public rooms from a remote homeserver @@ -1676,7 +1674,7 @@ class FederationClient(FederationBase): async def get_room_complexity( self, destination: str, room_id: str - ) -> Optional[JsonDict]: + ) -> JsonDict | None: """ Fetch the complexity of a remote room from another server. @@ -1987,10 +1985,10 @@ class FederationClient(FederationBase): max_timeout_ms: int, download_ratelimiter: Ratelimiter, ip_address: str, - ) -> Union[ - tuple[int, dict[bytes, list[bytes]], bytes], - tuple[int, dict[bytes, list[bytes]]], - ]: + ) -> ( + tuple[int, dict[bytes, list[bytes]], bytes] + | tuple[int, dict[bytes, list[bytes]]] + ): try: return await self.transport_layer.federation_download_media( destination, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 6e14f4a049..34abac1cec 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,8 +28,6 @@ from typing import ( Callable, Collection, Mapping, - Optional, - Union, ) from prometheus_client import Counter, Gauge, Histogram @@ -176,13 +174,11 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. - self._state_resp_cache: ResponseCache[tuple[str, Optional[str]]] = ( - ResponseCache( - clock=hs.get_clock(), - name="state_resp", - server_name=self.server_name, - timeout_ms=30000, - ) + self._state_resp_cache: ResponseCache[tuple[str, str | None]] = ResponseCache( + clock=hs.get_clock(), + name="state_resp", + server_name=self.server_name, + timeout_ms=30000, ) self._state_ids_resp_cache: ResponseCache[tuple[str, str]] = ResponseCache( clock=hs.get_clock(), @@ -666,7 +662,7 @@ class FederationServer(FederationBase): async def on_pdu_request( self, origin: str, event_id: str - ) -> tuple[int, Union[JsonDict, str]]: + ) -> tuple[int, JsonDict | str]: pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: @@ -763,7 +759,7 @@ class FederationServer(FederationBase): prev_state_ids = await context.get_prev_state_ids() state_event_ids: Collection[str] - servers_in_room: Optional[Collection[str]] + servers_in_room: Collection[str] | None if caller_supports_partial_state: summary = await self.store.get_room_summary(room_id) state_event_ids = _get_event_ids_for_partial_state_join( @@ -1126,7 +1122,7 @@ class FederationServer(FederationBase): return {"events": serialize_and_filter_pdus(missing_events, time_now)} - async def on_openid_userinfo(self, token: str) -> Optional[str]: + async def on_openid_userinfo(self, token: str) -> str | None: ts_now_ms = self._clock.time_msec() return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) @@ -1205,7 +1201,7 @@ class FederationServer(FederationBase): async def _get_next_nonspam_staged_event_for_room( self, room_id: str, room_version: RoomVersion - ) -> Optional[tuple[str, EventBase]]: + ) -> tuple[str, EventBase] | None: """Fetch the first non-spam event from staging queue. Args: @@ -1246,8 +1242,8 @@ class FederationServer(FederationBase): room_id: str, room_version: RoomVersion, lock: Lock, - latest_origin: Optional[str] = None, - latest_event: Optional[EventBase] = None, + latest_origin: str | None = None, + latest_event: EventBase | None = None, ) -> None: """Process events in the staging area for the given room. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 5628130429..dca13191fc 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -27,7 +27,6 @@ These actions are mostly only used by the :py:mod:`.replication` module. """ import logging -from typing import Optional from synapse.federation.units import Transaction from synapse.storage.databases.main import DataStore @@ -44,7 +43,7 @@ class TransactionActions: async def have_responded( self, origin: str, transaction: Transaction - ) -> Optional[tuple[int, JsonDict]]: + ) -> tuple[int, JsonDict] | None: """Have we already responded to a transaction with the same id and origin? diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 80f31798e8..cf70e10a58 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -42,7 +42,6 @@ from typing import ( TYPE_CHECKING, Hashable, Iterable, - Optional, Sized, ) @@ -217,7 +216,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): destination: str, edu_type: str, content: JsonDict, - key: Optional[Hashable] = None, + key: Hashable | None = None, ) -> None: """As per FederationSender""" if self.is_mine_server_name(destination): diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 229ae647c0..0bd97c25df 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -138,7 +138,6 @@ from typing import ( Hashable, Iterable, Literal, - Optional, ) import attr @@ -266,7 +265,7 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): destination: str, edu_type: str, content: JsonDict, - key: Optional[Hashable] = None, + key: Hashable | None = None, ) -> None: """Construct an Edu object, and queue it for sending @@ -410,7 +409,7 @@ class FederationSender(AbstractFederationSender): self.is_mine_id = hs.is_mine_id self.is_mine_server_name = hs.is_mine_server_name - self._presence_router: Optional["PresenceRouter"] = None + self._presence_router: "PresenceRouter" | None = None self._transaction_manager = TransactionManager(hs) self._instance_name = hs.get_instance_name() @@ -481,7 +480,7 @@ class FederationSender(AbstractFederationSender): def _get_per_destination_queue( self, destination: str - ) -> Optional[PerDestinationQueue]: + ) -> PerDestinationQueue | None: """Get or create a PerDestinationQueue for the given destination Args: @@ -605,7 +604,7 @@ class FederationSender(AbstractFederationSender): ) return - destinations: Optional[Collection[str]] = None + destinations: Collection[str] | None = None if not event.prev_event_ids(): # If there are no prev event IDs then the state is empty # and so no remote servers in the room @@ -1010,7 +1009,7 @@ class FederationSender(AbstractFederationSender): destination: str, edu_type: str, content: JsonDict, - key: Optional[Hashable] = None, + key: Hashable | None = None, ) -> None: """Construct an Edu object, and queue it for sending @@ -1038,7 +1037,7 @@ class FederationSender(AbstractFederationSender): self.send_edu(edu, key) - def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None: + def send_edu(self, edu: Edu, key: Hashable | None) -> None: """Queue an EDU for sending Args: @@ -1134,7 +1133,7 @@ class FederationSender(AbstractFederationSender): In order to reduce load spikes, adds a delay between each destination. """ - last_processed: Optional[str] = None + last_processed: str | None = None while not self._is_shutdown: destinations_to_wake = ( diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index ecf4789d76..4a1b84aed7 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -23,7 +23,7 @@ import datetime import logging from collections import OrderedDict from types import TracebackType -from typing import TYPE_CHECKING, Hashable, Iterable, Optional +from typing import TYPE_CHECKING, Hashable, Iterable import attr from prometheus_client import Counter @@ -121,7 +121,7 @@ class PerDestinationQueue: self._destination = destination self.transmission_loop_running = False self._transmission_loop_enabled = True - self.active_transmission_loop: Optional[defer.Deferred] = None + self.active_transmission_loop: defer.Deferred | None = None # Flag to signal to any running transmission loop that there is new data # queued up to be sent. @@ -142,7 +142,7 @@ class PerDestinationQueue: # Cache of the last successfully-transmitted stream ordering for this # destination (we are the only updater so this is safe) - self._last_successful_stream_ordering: Optional[int] = None + self._last_successful_stream_ordering: int | None = None # a queue of pending PDUs self._pending_pdus: list[EventBase] = [] @@ -742,9 +742,9 @@ class _TransactionQueueManager: queue: PerDestinationQueue - _device_stream_id: Optional[int] = None - _device_list_id: Optional[int] = None - _last_stream_ordering: Optional[int] = None + _device_stream_id: int | None = None + _device_list_id: int | None = None + _last_stream_ordering: int | None = None _pdus: list[EventBase] = attr.Factory(list) async def __aenter__(self) -> tuple[list[EventBase], list[Edu]]: @@ -845,9 +845,9 @@ class _TransactionQueueManager: async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> None: if exc_type is not None: # Failed to send transaction, so we bail out. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index ee15b4804e..35d3c30c69 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -31,8 +31,6 @@ from typing import ( Generator, Iterable, Mapping, - Optional, - Union, ) import attr @@ -122,7 +120,7 @@ class TransportLayerClient: ) async def get_event( - self, destination: str, event_id: str, timeout: Optional[int] = None + self, destination: str, event_id: str, timeout: int | None = None ) -> JsonDict: """Requests the pdu with give id and origin from the given server. @@ -144,7 +142,7 @@ class TransportLayerClient: ) async def get_policy_recommendation_for_pdu( - self, destination: str, event: EventBase, timeout: Optional[int] = None + self, destination: str, event: EventBase, timeout: int | None = None ) -> JsonDict: """Requests the policy recommendation for the given pdu from the given policy server. @@ -171,7 +169,7 @@ class TransportLayerClient: ) async def ask_policy_server_to_sign_event( - self, destination: str, event: EventBase, timeout: Optional[int] = None + self, destination: str, event: EventBase, timeout: int | None = None ) -> JsonDict: """Requests that the destination server (typically a policy server) sign the event as not spam. @@ -198,7 +196,7 @@ class TransportLayerClient: async def backfill( self, destination: str, room_id: str, event_tuples: Collection[str], limit: int - ) -> Optional[Union[JsonDict, list]]: + ) -> JsonDict | list | None: """Requests `limit` previous PDUs in a given context before list of PDUs. @@ -235,7 +233,7 @@ class TransportLayerClient: async def timestamp_to_event( self, destination: str, room_id: str, timestamp: int, direction: Direction - ) -> Union[JsonDict, list]: + ) -> JsonDict | list: """ Calls a remote federating server at `destination` asking for their closest event to the given timestamp in the given direction. @@ -270,7 +268,7 @@ class TransportLayerClient: async def send_transaction( self, transaction: Transaction, - json_data_callback: Optional[Callable[[], JsonDict]] = None, + json_data_callback: Callable[[], JsonDict] | None = None, ) -> JsonDict: """Sends the given Transaction to its destination @@ -343,7 +341,7 @@ class TransportLayerClient: room_id: str, user_id: str, membership: str, - params: Optional[Mapping[str, Union[str, Iterable[str]]]], + params: Mapping[str, str | Iterable[str]] | None, ) -> JsonDict: """Asks a remote server to build and sign us a membership event @@ -528,11 +526,11 @@ class TransportLayerClient: async def get_public_rooms( self, remote_server: str, - limit: Optional[int] = None, - since_token: Optional[str] = None, - search_filter: Optional[dict] = None, + limit: int | None = None, + since_token: str | None = None, + search_filter: dict | None = None, include_all_networks: bool = False, - third_party_instance_id: Optional[str] = None, + third_party_instance_id: str | None = None, ) -> JsonDict: """Get the list of public rooms from a remote homeserver @@ -567,7 +565,7 @@ class TransportLayerClient: ) raise else: - args: dict[str, Union[str, Iterable[str]]] = { + args: dict[str, str | Iterable[str]] = { "include_all_networks": "true" if include_all_networks else "false" } if third_party_instance_id: @@ -694,7 +692,7 @@ class TransportLayerClient: user: UserID, destination: str, query_content: JsonDict, - timeout: Optional[int], + timeout: int | None, ) -> JsonDict: """Claim one-time keys for a list of devices hosted on a remote server. @@ -740,7 +738,7 @@ class TransportLayerClient: user: UserID, destination: str, query_content: JsonDict, - timeout: Optional[int], + timeout: int | None, ) -> JsonDict: """Claim one-time keys for a list of devices hosted on a remote server. @@ -997,13 +995,13 @@ class SendJoinResponse: event_dict: JsonDict # The parsed join event from the /send_join response. This will be None if # "event" is not included in the response. - event: Optional[EventBase] = None + event: EventBase | None = None # The room state is incomplete members_omitted: bool = False # List of servers in the room - servers_in_room: Optional[list[str]] = None + servers_in_room: list[str] | None = None @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index d5f05f7290..6d92d00523 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -20,7 +20,7 @@ # # import logging -from typing import TYPE_CHECKING, Iterable, Literal, Optional +from typing import TYPE_CHECKING, Iterable, Literal from synapse.api.errors import FederationDeniedError, SynapseError from synapse.federation.transport.server._base import ( @@ -52,7 +52,7 @@ logger = logging.getLogger(__name__) class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" - def __init__(self, hs: "HomeServer", servlet_groups: Optional[list[str]] = None): + def __init__(self, hs: "HomeServer", servlet_groups: list[str] | None = None): """Initialize the TransportLayerServer Will by default register all servlets. For custom behaviour, pass in @@ -135,7 +135,7 @@ class PublicRoomList(BaseFederationServlet): if not self.allow_access: raise FederationDeniedError(origin) - limit: Optional[int] = parse_integer_from_args(query, "limit", 0) + limit: int | None = parse_integer_from_args(query, "limit", 0) since_token = parse_string_from_args(query, "since", None) include_all_networks = parse_boolean_from_args( query, "include_all_networks", default=False @@ -170,7 +170,7 @@ class PublicRoomList(BaseFederationServlet): if not self.allow_access: raise FederationDeniedError(origin) - limit: Optional[int] = int(content.get("limit", 100)) + limit: int | None = int(content.get("limit", 100)) since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -240,7 +240,7 @@ class OpenIdUserInfo(BaseFederationServlet): async def on_GET( self, - origin: Optional[str], + origin: str | None, content: Literal[None], query: dict[bytes, list[bytes]], ) -> tuple[int, JsonDict]: @@ -281,7 +281,7 @@ def register_servlets( resource: HttpServer, authenticator: Authenticator, ratelimiter: FederationRateLimiter, - servlet_groups: Optional[Iterable[str]] = None, + servlet_groups: Iterable[str] | None = None, ) -> None: """Initialize and register servlet classes. diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 146cbebb27..52c0c96a3f 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -24,7 +24,7 @@ import logging import re import time from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, cast from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.urls import FEDERATION_V1_PREFIX @@ -77,7 +77,7 @@ class Authenticator: # A method just so we can pass 'self' as the authenticator to the Servlets async def authenticate_request( - self, request: SynapseRequest, content: Optional[JsonDict] + self, request: SynapseRequest, content: JsonDict | None ) -> str: now = self._clock.time_msec() json_request: JsonDict = { @@ -165,7 +165,7 @@ class Authenticator: logger.exception("Error resetting retry timings on %s", origin) -def _parse_auth_header(header_bytes: bytes) -> tuple[str, str, str, Optional[str]]: +def _parse_auth_header(header_bytes: bytes) -> tuple[str, str, str, str | None]: """Parse an X-Matrix auth header Args: @@ -252,7 +252,7 @@ class BaseFederationServlet: components as specified in the path match regexp. Returns: - Optional[tuple[int, object]]: either (response code, response object) to + tuple[int, object] | None: either (response code, response object) to return a JSON response, or None if the request has already been handled. Raises: @@ -289,7 +289,7 @@ class BaseFederationServlet: @functools.wraps(func) async def new_func( request: SynapseRequest, *args: Any, **kwargs: str - ) -> Optional[tuple[int, Any]]: + ) -> tuple[int, Any] | None: """A callback which can be passed to HttpServer.RegisterPaths Args: @@ -309,7 +309,7 @@ class BaseFederationServlet: try: with start_active_span("authenticate_request"): - origin: Optional[str] = await authenticator.authenticate_request( + origin: str | None = await authenticator.authenticate_request( request, content ) except NoAuthenticationError: diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 54c7dac1b7..a7c297c0b7 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -24,9 +24,7 @@ from typing import ( TYPE_CHECKING, Literal, Mapping, - Optional, Sequence, - Union, ) from synapse.api.constants import Direction, EduTypes @@ -156,7 +154,7 @@ class FederationEventServlet(BaseFederationServerServlet): content: Literal[None], query: dict[bytes, list[bytes]], event_id: str, - ) -> tuple[int, Union[JsonDict, str]]: + ) -> tuple[int, JsonDict | str]: return await self.handler.on_pdu_request(origin, event_id) @@ -642,7 +640,7 @@ class On3pidBindServlet(BaseFederationServerServlet): REQUIRE_AUTH = False async def on_POST( - self, origin: Optional[str], content: JsonDict, query: dict[bytes, list[bytes]] + self, origin: str | None, content: JsonDict, query: dict[bytes, list[bytes]] ) -> tuple[int, JsonDict]: if "invites" in content: last_exception = None @@ -676,7 +674,7 @@ class FederationVersionServlet(BaseFederationServlet): async def on_GET( self, - origin: Optional[str], + origin: str | None, content: Literal[None], query: dict[bytes, list[bytes]], ) -> tuple[int, JsonDict]: @@ -812,7 +810,7 @@ class FederationMediaDownloadServlet(BaseFederationServerServlet): async def on_GET( self, - origin: Optional[str], + origin: str | None, content: Literal[None], request: SynapseRequest, media_id: str, @@ -852,7 +850,7 @@ class FederationMediaThumbnailServlet(BaseFederationServerServlet): async def on_GET( self, - origin: Optional[str], + origin: str | None, content: Literal[None], request: SynapseRequest, media_id: str, diff --git a/synapse/federation/units.py b/synapse/federation/units.py index bff45bc2a9..547db9a394 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -24,7 +24,7 @@ server protocol. """ import logging -from typing import Optional, Sequence +from typing import Sequence import attr @@ -70,7 +70,7 @@ class Edu: getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}" -def _none_to_list(edus: Optional[list[JsonDict]]) -> list[JsonDict]: +def _none_to_list(edus: list[JsonDict] | None) -> list[JsonDict]: if edus is None: return [] return edus @@ -128,6 +128,6 @@ def filter_pdus_for_valid_depth(pdus: Sequence[JsonDict]) -> list[JsonDict]: def serialize_and_filter_pdus( - pdus: Sequence[EventBase], time_now: Optional[int] = None + pdus: Sequence[EventBase], time_now: int | None = None ) -> list[JsonDict]: return filter_pdus_for_valid_depth([pdu.get_pdu_json(time_now) for pdu in pdus]) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 4492612859..c6168377ee 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -21,7 +21,7 @@ # import logging import random -from typing import TYPE_CHECKING, Awaitable, Callable, Optional +from typing import TYPE_CHECKING, Awaitable, Callable from synapse.api.constants import AccountDataTypes from synapse.replication.http.account_data import ( @@ -40,9 +40,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[ - [str, Optional[str], str, JsonDict], Awaitable -] +ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[[str, str | None, str, JsonDict], Awaitable] class AccountDataHandler: @@ -72,7 +70,7 @@ class AccountDataHandler: ] = [] def register_module_callbacks( - self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None + self, on_account_data_updated: ON_ACCOUNT_DATA_UPDATED_CALLBACK | None = None ) -> None: """Register callbacks from modules.""" if on_account_data_updated is not None: @@ -81,7 +79,7 @@ class AccountDataHandler: async def _notify_modules( self, user_id: str, - room_id: Optional[str], + room_id: str | None, account_data_type: str, content: JsonDict, ) -> None: @@ -143,7 +141,7 @@ class AccountDataHandler: async def remove_account_data_for_room( self, user_id: str, room_id: str, account_data_type: str - ) -> Optional[int]: + ) -> int | None: """ Deletes the room account data for the given user and account data type. @@ -219,7 +217,7 @@ class AccountDataHandler: async def remove_account_data_for_user( self, user_id: str, account_data_type: str - ) -> Optional[int]: + ) -> int | None: """Removes a piece of global account_data for a user. Args: @@ -324,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): limit: int, room_ids: StrCollection, is_guest: bool, - explicit_room_id: Optional[str] = None, + explicit_room_id: str | None = None, ) -> tuple[list[JsonDict], int]: user_id = user.to_string() last_stream_id = from_key diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index a805de1f35..bc50efa1a7 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -21,7 +21,7 @@ import email.mime.multipart import email.utils import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -108,8 +108,8 @@ class AccountValidityHandler: async def on_user_login( self, user_id: str, - auth_provider_type: Optional[str], - auth_provider_id: Optional[str], + auth_provider_type: str | None, + auth_provider_id: str | None, ) -> None: """Tell third-party modules about a user logins. @@ -326,9 +326,9 @@ class AccountValidityHandler: async def renew_account_for_user( self, user_id: str, - expiration_ts: Optional[int] = None, + expiration_ts: int | None = None, email_sent: bool = False, - renewal_token: Optional[str] = None, + renewal_token: str | None = None, ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 3faaa4d2b3..c979752f7f 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -25,7 +25,6 @@ from typing import ( TYPE_CHECKING, Any, Mapping, - Optional, Sequence, ) @@ -71,7 +70,7 @@ class AdminHandler: self.hs = hs - async def get_redact_task(self, redact_id: str) -> Optional[ScheduledTask]: + async def get_redact_task(self, redact_id: str) -> ScheduledTask | None: """Get the current status of an active redaction process Args: @@ -99,11 +98,9 @@ class AdminHandler: return ret - async def get_user(self, user: UserID) -> Optional[JsonMapping]: + async def get_user(self, user: UserID) -> JsonMapping | None: """Function to get user details""" - user_info: Optional[UserInfo] = await self._store.get_user_by_id( - user.to_string() - ) + user_info: UserInfo | None = await self._store.get_user_by_id(user.to_string()) if user_info is None: return None @@ -355,8 +352,8 @@ class AdminHandler: rooms: list, requester: JsonMapping, use_admin: bool, - reason: Optional[str], - limit: Optional[int], + reason: str | None, + limit: int | None, ) -> str: """ Start a task redacting the events of the given user in the given rooms @@ -408,7 +405,7 @@ class AdminHandler: async def _redact_all_events( self, task: ScheduledTask - ) -> tuple[TaskStatus, Optional[Mapping[str, Any]], Optional[str]]: + ) -> tuple[TaskStatus, Mapping[str, Any] | None, str | None]: """ Task to redact all of a users events in the given rooms, tracking which, if any, events whose redaction failed diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 5240178d80..c91d2adbe1 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -24,8 +24,6 @@ from typing import ( Collection, Iterable, Mapping, - Optional, - Union, ) from prometheus_client import Counter @@ -240,8 +238,8 @@ class ApplicationServicesHandler: def notify_interested_services_ephemeral( self, stream_key: StreamKeyType, - new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], - users: Collection[Union[str, UserID]], + new_token: int | RoomStreamToken | MultiWriterStreamToken, + users: Collection[str | UserID], ) -> None: """ This is called by the notifier in the background when an ephemeral event is handled @@ -340,8 +338,8 @@ class ApplicationServicesHandler: self, services: list[ApplicationService], stream_key: StreamKeyType, - new_token: Union[int, MultiWriterStreamToken], - users: Collection[Union[str, UserID]], + new_token: int | MultiWriterStreamToken, + users: Collection[str | UserID], ) -> None: logger.debug("Checking interested services for %s", stream_key) with Measure( @@ -498,8 +496,8 @@ class ApplicationServicesHandler: async def _handle_presence( self, service: ApplicationService, - users: Collection[Union[str, UserID]], - new_token: Optional[int], + users: Collection[str | UserID], + new_token: int | None, ) -> list[JsonMapping]: """ Return the latest presence updates that the given application service should receive. @@ -559,7 +557,7 @@ class ApplicationServicesHandler: self, service: ApplicationService, new_token: int, - users: Collection[Union[str, UserID]], + users: Collection[str | UserID], ) -> list[JsonDict]: """ Given an application service, determine which events it should receive @@ -733,7 +731,7 @@ class ApplicationServicesHandler: async def query_room_alias_exists( self, room_alias: RoomAlias - ) -> Optional[RoomAliasMapping]: + ) -> RoomAliasMapping | None: """Check if an application service knows this room alias exists. Args: @@ -782,7 +780,7 @@ class ApplicationServicesHandler: return ret async def get_3pe_protocols( - self, only_protocol: Optional[str] = None + self, only_protocol: str | None = None ) -> dict[str, JsonDict]: services = self.store.get_app_services() protocols: dict[str, list[JsonDict]] = {} @@ -935,7 +933,7 @@ class ApplicationServicesHandler: return claimed_keys, missing async def query_keys( - self, query: Mapping[str, Optional[list[str]]] + self, query: Mapping[str, list[str] | None] ) -> dict[str, dict[str, dict[str, JsonDict]]]: """Query application services for device keys. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ed796cfe06..d9355d33da 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -33,8 +33,6 @@ from typing import ( Callable, Iterable, Mapping, - Optional, - Union, cast, ) @@ -289,7 +287,7 @@ class AuthHandler: request_body: dict[str, Any], description: str, can_skip_ui_auth: bool = False, - ) -> tuple[dict, Optional[str]]: + ) -> tuple[dict, str | None]: """ Checks that the user is who they claim to be, via a UI auth. @@ -440,7 +438,7 @@ class AuthHandler: request: SynapseRequest, clientdict: dict[str, Any], description: str, - get_new_session_data: Optional[Callable[[], JsonDict]] = None, + get_new_session_data: Callable[[], JsonDict] | None = None, ) -> tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration @@ -487,7 +485,7 @@ class AuthHandler: all the stages in any of the permitted flows. """ - sid: Optional[str] = None + sid: str | None = None authdict = clientdict.pop("auth", {}) if "session" in authdict: sid = authdict["session"] @@ -637,7 +635,7 @@ class AuthHandler: authdict["session"], stagetype, result ) - def get_session_id(self, clientdict: dict[str, Any]) -> Optional[str]: + def get_session_id(self, clientdict: dict[str, Any]) -> str | None: """ Gets the session ID for a client given the client dictionary @@ -673,7 +671,7 @@ class AuthHandler: raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) async def get_session_data( - self, session_id: str, key: str, default: Optional[Any] = None + self, session_id: str, key: str, default: Any | None = None ) -> Any: """ Retrieve data stored with set_session_data @@ -699,7 +697,7 @@ class AuthHandler: async def _check_auth_dict( self, authdict: dict[str, Any], clientip: str - ) -> Union[dict[str, Any], str]: + ) -> dict[str, Any] | str: """Attempt to validate the auth dict provided by a client Args: @@ -774,9 +772,9 @@ class AuthHandler: async def refresh_token( self, refresh_token: str, - access_token_valid_until_ms: Optional[int], - refresh_token_valid_until_ms: Optional[int], - ) -> tuple[str, str, Optional[int]]: + access_token_valid_until_ms: int | None, + refresh_token_valid_until_ms: int | None, + ) -> tuple[str, str, int | None]: """ Consumes a refresh token and generate both a new access token and a new refresh token from it. @@ -909,8 +907,8 @@ class AuthHandler: self, user_id: str, duration_ms: int = (2 * 60 * 1000), - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, + auth_provider_id: str | None = None, + auth_provider_session_id: str | None = None, ) -> str: login_token = self.generate_login_token() now = self._clock.time_msec() @@ -928,8 +926,8 @@ class AuthHandler: self, user_id: str, device_id: str, - expiry_ts: Optional[int], - ultimate_session_expiry_ts: Optional[int], + expiry_ts: int | None, + ultimate_session_expiry_ts: int | None, ) -> tuple[str, int]: """ Creates a new refresh token for the user with the given user ID. @@ -961,11 +959,11 @@ class AuthHandler: async def create_access_token_for_user_id( self, user_id: str, - device_id: Optional[str], - valid_until_ms: Optional[int], - puppets_user_id: Optional[str] = None, + device_id: str | None, + valid_until_ms: int | None, + puppets_user_id: str | None = None, is_appservice_ghost: bool = False, - refresh_token_id: Optional[int] = None, + refresh_token_id: int | None = None, ) -> str: """ Creates a new access token for the user with the given user ID. @@ -1034,7 +1032,7 @@ class AuthHandler: return access_token - async def check_user_exists(self, user_id: str) -> Optional[str]: + async def check_user_exists(self, user_id: str) -> str | None: """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. @@ -1061,9 +1059,7 @@ class AuthHandler: """ return await self.store.is_user_approved(user_id) - async def _find_user_id_and_pwd_hash( - self, user_id: str - ) -> Optional[tuple[str, str]]: + async def _find_user_id_and_pwd_hash(self, user_id: str) -> tuple[str, str] | None: """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. @@ -1141,7 +1137,7 @@ class AuthHandler: login_submission: dict[str, Any], ratelimit: bool = False, is_reauth: bool = False, - ) -> tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: + ) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate auth types which don't @@ -1297,7 +1293,7 @@ class AuthHandler: self, username: str, login_submission: dict[str, Any], - ) -> tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: + ) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None]: """Helper for validate_login Handles login, once we've mapped 3pids onto userids @@ -1386,7 +1382,7 @@ class AuthHandler: async def check_password_provider_3pid( self, medium: str, address: str, password: str - ) -> tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]: + ) -> tuple[str | None, Callable[["LoginResponse"], Awaitable[None]] | None]: """Check if a password provider is able to validate a thirdparty login Args: @@ -1413,7 +1409,7 @@ class AuthHandler: # if result is None then return (None, None) return None, None - async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: + async def _check_local_password(self, user_id: str, password: str) -> str | None: """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are @@ -1528,8 +1524,8 @@ class AuthHandler: async def delete_access_tokens_for_user( self, user_id: str, - except_token_id: Optional[int] = None, - device_id: Optional[str] = None, + except_token_id: int | None = None, + device_id: str | None = None, ) -> None: """Invalidate access tokens belonging to a user @@ -1700,9 +1696,7 @@ class AuthHandler: return await defer_to_thread(self.hs.get_reactor(), _do_hash) - async def validate_hash( - self, password: str, stored_hash: Union[bytes, str] - ) -> bool: + async def validate_hash(self, password: str, stored_hash: bytes | str) -> bool: """Validates that self.hash(password) == stored_hash. Args: @@ -1799,9 +1793,9 @@ class AuthHandler: auth_provider_id: str, request: Request, client_redirect_url: str, - extra_attributes: Optional[JsonDict] = None, + extra_attributes: JsonDict | None = None, new_user: bool = False, - auth_provider_session_id: Optional[str] = None, + auth_provider_session_id: str | None = None, ) -> None: """Having figured out a mxid for this user, complete the HTTP request @@ -1960,7 +1954,7 @@ def load_single_legacy_password_auth_provider( # All methods that the module provides should be async, but this wasn't enforced # in the old module system, so we wrap them if needed - def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + def async_wrapper(f: Callable | None) -> Callable[..., Awaitable] | None: # f might be None if the callback isn't implemented by the module. In this # case we don't want to register a callback at all so we return None. if f is None: @@ -1973,7 +1967,7 @@ def load_single_legacy_password_auth_provider( async def wrapped_check_password( username: str, login_type: str, login_dict: JsonDict - ) -> Optional[tuple[str, Optional[Callable]]]: + ) -> tuple[str, Callable | None] | None: # We've already made sure f is not None above, but mypy doesn't do well # across function boundaries so we need to tell it f is definitely not # None. @@ -1992,12 +1986,12 @@ def load_single_legacy_password_auth_provider( return wrapped_check_password # We need to wrap check_auth as in the old form it could return - # just a str, but now it must return Optional[tuple[str, Optional[Callable]] + # just a str, but now it must return tuple[str, Callable | None] | None if f.__name__ == "check_auth": async def wrapped_check_auth( username: str, login_type: str, login_dict: JsonDict - ) -> Optional[tuple[str, Optional[Callable]]]: + ) -> tuple[str, Callable | None] | None: # We've already made sure f is not None above, but mypy doesn't do well # across function boundaries so we need to tell it f is definitely not # None. @@ -2013,12 +2007,12 @@ def load_single_legacy_password_auth_provider( return wrapped_check_auth # We need to wrap check_3pid_auth as in the old form it could return - # just a str, but now it must return Optional[tuple[str, Optional[Callable]] + # just a str, but now it must return tuple[str, Callable | None] | None if f.__name__ == "check_3pid_auth": async def wrapped_check_3pid_auth( medium: str, address: str, password: str - ) -> Optional[tuple[str, Optional[Callable]]]: + ) -> tuple[str, Callable | None] | None: # We've already made sure f is not None above, but mypy doesn't do well # across function boundaries so we need to tell it f is definitely not # None. @@ -2044,10 +2038,10 @@ def load_single_legacy_password_auth_provider( # If the module has these methods implemented, then we pull them out # and register them as hooks. - check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper( + check_3pid_auth_hook: CHECK_3PID_AUTH_CALLBACK | None = async_wrapper( getattr(provider, "check_3pid_auth", None) ) - on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper( + on_logged_out_hook: ON_LOGGED_OUT_CALLBACK | None = async_wrapper( getattr(provider, "on_logged_out", None) ) @@ -2085,24 +2079,20 @@ def load_single_legacy_password_auth_provider( CHECK_3PID_AUTH_CALLBACK = Callable[ [str, str, str], - Awaitable[ - Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] - ], + Awaitable[tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None], ] -ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable] +ON_LOGGED_OUT_CALLBACK = Callable[[str, str | None, str], Awaitable] CHECK_AUTH_CALLBACK = Callable[ [str, str, JsonDict], - Awaitable[ - Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] - ], + Awaitable[tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None], ] GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ [JsonDict, JsonDict], - Awaitable[Optional[str]], + Awaitable[str | None], ] GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ [JsonDict, JsonDict], - Awaitable[Optional[str]], + Awaitable[str | None], ] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] @@ -2133,18 +2123,15 @@ class PasswordAuthProvider: def register_password_auth_provider_callbacks( self, - check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, - on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, - is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None, - auth_checkers: Optional[ - dict[tuple[str, tuple[str, ...]], CHECK_AUTH_CALLBACK] - ] = None, - get_username_for_registration: Optional[ - GET_USERNAME_FOR_REGISTRATION_CALLBACK - ] = None, - get_displayname_for_registration: Optional[ - GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK - ] = None, + check_3pid_auth: CHECK_3PID_AUTH_CALLBACK | None = None, + on_logged_out: ON_LOGGED_OUT_CALLBACK | None = None, + is_3pid_allowed: IS_3PID_ALLOWED_CALLBACK | None = None, + auth_checkers: dict[tuple[str, tuple[str, ...]], CHECK_AUTH_CALLBACK] + | None = None, + get_username_for_registration: GET_USERNAME_FOR_REGISTRATION_CALLBACK + | None = None, + get_displayname_for_registration: GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + | None = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2214,7 +2201,7 @@ class PasswordAuthProvider: async def check_auth( self, username: str, login_type: str, login_dict: JsonDict - ) -> Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: + ) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None: """Check if the user has presented valid login credentials Args: @@ -2245,14 +2232,14 @@ class PasswordAuthProvider: continue if result is not None: - # Check that the callback returned a Tuple[str, Optional[Callable]] + # Check that the callback returned a tuple[str, Callable | None] # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks # result is always the right type, but as it is 3rd party code it might not be if not isinstance(result, tuple) or len(result) != 2: logger.warning( # type: ignore[unreachable] "Wrong type returned by module API callback %s: %s, expected" - " Optional[tuple[str, Optional[Callable]]]", + " tuple[str, Callable | None] | None", callback, result, ) @@ -2265,24 +2252,24 @@ class PasswordAuthProvider: if not isinstance(str_result, str): logger.warning( # type: ignore[unreachable] "Wrong type returned by module API callback %s: %s, expected" - " Optional[tuple[str, Optional[Callable]]]", + " tuple[str, Callable | None] | None", callback, result, ) continue - # the second should be Optional[Callable] + # the second should be Callable | None if callback_result is not None: if not callable(callback_result): logger.warning( # type: ignore[unreachable] "Wrong type returned by module API callback %s: %s, expected" - " Optional[tuple[str, Optional[Callable]]]", + " tuple[str, Callable | None] | None", callback, result, ) continue - # The result is a (str, Optional[callback]) tuple so return the successful result + # The result is a (str, callback | None) tuple so return the successful result return result # If this point has been reached then none of the callbacks successfully authenticated @@ -2291,7 +2278,7 @@ class PasswordAuthProvider: async def check_3pid_auth( self, medium: str, address: str, password: str - ) -> Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: + ) -> tuple[str, Callable[["LoginResponse"], Awaitable[None]] | None] | None: # This function is able to return a deferred that either # resolves None, meaning authentication failure, or upon # success, to a str (which is the user_id) or a tuple of @@ -2308,14 +2295,14 @@ class PasswordAuthProvider: continue if result is not None: - # Check that the callback returned a Tuple[str, Optional[Callable]] + # Check that the callback returned a tuple[str, Callable | None] # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks # result is always the right type, but as it is 3rd party code it might not be if not isinstance(result, tuple) or len(result) != 2: logger.warning( # type: ignore[unreachable] "Wrong type returned by module API callback %s: %s, expected" - " Optional[tuple[str, Optional[Callable]]]", + " tuple[str, Callable | None] | None", callback, result, ) @@ -2328,24 +2315,24 @@ class PasswordAuthProvider: if not isinstance(str_result, str): logger.warning( # type: ignore[unreachable] "Wrong type returned by module API callback %s: %s, expected" - " Optional[tuple[str, Optional[Callable]]]", + " tuple[str, Callable | None] | None", callback, result, ) continue - # the second should be Optional[Callable] + # the second should be Callable | None if callback_result is not None: if not callable(callback_result): logger.warning( # type: ignore[unreachable] "Wrong type returned by module API callback %s: %s, expected" - " Optional[tuple[str, Optional[Callable]]]", + " tuple[str, Callable | None] | None", callback, result, ) continue - # The result is a (str, Optional[callback]) tuple so return the successful result + # The result is a (str, callback | None) tuple so return the successful result return result # If this point has been reached then none of the callbacks successfully authenticated @@ -2353,7 +2340,7 @@ class PasswordAuthProvider: return None async def on_logged_out( - self, user_id: str, device_id: Optional[str], access_token: str + self, user_id: str, device_id: str | None, access_token: str ) -> None: # call all of the on_logged_out callbacks for callback in self.on_logged_out_callbacks: @@ -2367,7 +2354,7 @@ class PasswordAuthProvider: self, uia_results: JsonDict, params: JsonDict, - ) -> Optional[str]: + ) -> str | None: """Defines the username to use when registering the user, using the credentials and parameters provided during the UIA flow. @@ -2412,7 +2399,7 @@ class PasswordAuthProvider: self, uia_results: JsonDict, params: JsonDict, - ) -> Optional[str]: + ) -> str | None: """Defines the display name to use when registering the user, using the credentials and parameters provided during the UIA flow. diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index 438dcf9f2c..dbcf074d2b 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -20,7 +20,7 @@ # import logging import urllib.parse -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from xml.etree import ElementTree as ET import attr @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) class CasError(Exception): """Used to catch errors when validating the CAS ticket.""" - def __init__(self, error: str, error_description: Optional[str] = None): + def __init__(self, error: str, error_description: str | None = None): self.error = error self.error_description = error_description @@ -54,7 +54,7 @@ class CasError(Exception): @attr.s(slots=True, frozen=True, auto_attribs=True) class CasResponse: username: str - attributes: dict[str, list[Optional[str]]] + attributes: dict[str, list[str | None]] class CasHandler: @@ -145,7 +145,7 @@ class CasHandler: except PartialDownloadError as pde: # Twisted raises this error if the connection is closed, # even if that's being used old-http style to signal end-of-data - # Assertion is for mypy's benefit. Error.response is Optional[bytes], + # Assertion is for mypy's benefit. Error.response is bytes | None, # but a PartialDownloadError should always have a non-None response. assert pde.response is not None body = pde.response @@ -186,7 +186,7 @@ class CasHandler: # Iterate through the nodes and pull out the user and any extra attributes. user = None - attributes: dict[str, list[Optional[str]]] = {} + attributes: dict[str, list[str | None]] = {} for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -213,8 +213,8 @@ class CasHandler: async def handle_redirect_request( self, request: SynapseRequest, - client_redirect_url: Optional[bytes], - ui_auth_session_id: Optional[str] = None, + client_redirect_url: bytes | None, + ui_auth_session_id: str | None = None, ) -> str: """Generates a URL for the CAS server where the client should be redirected. @@ -245,8 +245,8 @@ class CasHandler: self, request: SynapseRequest, ticket: str, - client_redirect_url: Optional[str], - session: Optional[str], + client_redirect_url: str | None, + session: str | None, ) -> None: """ Called once the user has successfully authenticated with the SSO. @@ -292,8 +292,8 @@ class CasHandler: self, request: SynapseRequest, cas_response: CasResponse, - client_redirect_url: Optional[str], - session: Optional[str], + client_redirect_url: str | None, + session: str | None, ) -> None: """Handle a CAS response to a ticket request. @@ -384,7 +384,7 @@ class CasHandler: return UserAttributes(localpart=localpart, display_name=display_name) - async def grandfather_existing_users() -> Optional[str]: + async def grandfather_existing_users() -> str | None: # Since CAS did not always use the user_external_ids table, always # to attempt to map to existing users. user_id = UserID(localpart, self._hostname).to_string() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 204dffd288..e4c646ce87 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -20,7 +20,7 @@ # import itertools import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.constants import Membership from synapse.api.errors import SynapseError @@ -76,7 +76,7 @@ class DeactivateAccountHandler: user_id: str, erase_data: bool, requester: Requester, - id_server: Optional[str] = None, + id_server: str | None = None, by_admin: bool = False, ) -> bool: """Deactivate a user's account diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index b89b7416e6..3342420d7d 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -13,7 +13,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.internet.interfaces import IDelayedCall @@ -74,10 +74,10 @@ class DelayedEventsHandler: cfg=self._config.ratelimiting.rc_delayed_event_mgmt, ) - self._next_delayed_event_call: Optional[IDelayedCall] = None + self._next_delayed_event_call: IDelayedCall | None = None # The current position in the current_state_delta stream - self._event_pos: Optional[int] = None + self._event_pos: int | None = None # Guard to ensure we only process event deltas one at a time self._event_processing = False @@ -327,8 +327,8 @@ class DelayedEventsHandler: *, room_id: str, event_type: str, - state_key: Optional[str], - origin_server_ts: Optional[int], + state_key: str | None, + origin_server_ts: int | None, content: JsonDict, delay: int, ) -> str: @@ -526,7 +526,7 @@ class DelayedEventsHandler: state_key=state_key, ) - def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None: + def _schedule_next_at_or_none(self, next_send_ts: Timestamp | None) -> None: if next_send_ts is not None: self._schedule_next_at(next_send_ts) elif self._next_delayed_event_call is not None: @@ -560,7 +560,7 @@ class DelayedEventsHandler: async def _send_event( self, event: DelayedEventDetails, - txn_id: Optional[str] = None, + txn_id: str | None = None, ) -> None: user_id = UserID(event.user_localpart, self._config.server.server_name) user_id_str = user_id.to_string() @@ -622,7 +622,7 @@ class DelayedEventsHandler: def _get_current_ts(self) -> Timestamp: return Timestamp(self._clock.time_msec()) - def _next_send_ts_changed(self, next_send_ts: Optional[Timestamp]) -> bool: + def _next_send_ts_changed(self, next_send_ts: Timestamp | None) -> bool: # The DB alone knows if the next send time changed after adding/modifying # a delayed event, but if we were to ever miss updating our delayed call's # firing time, we may miss other updates. So, keep track of changes to the diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f0558fc737..3f1a5fe6d6 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -27,7 +27,6 @@ from typing import ( AbstractSet, Iterable, Mapping, - Optional, cast, ) @@ -89,7 +88,7 @@ MAX_DEVICE_DISPLAY_NAME_LEN = 100 DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 -def _check_device_name_length(name: Optional[str]) -> None: +def _check_device_name_length(name: str | None) -> None: """ Checks whether a device name is longer than the maximum allowed length. @@ -208,10 +207,10 @@ class DeviceHandler: async def check_device_registered( self, user_id: str, - device_id: Optional[str], - initial_device_display_name: Optional[str] = None, - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, + device_id: str | None, + initial_device_display_name: str | None = None, + auth_provider_id: str | None = None, + auth_provider_session_id: str | None = None, ) -> str: """ If the given device has not been registered, register it with the @@ -269,7 +268,7 @@ class DeviceHandler: @trace async def delete_all_devices_for_user( - self, user_id: str, except_device_id: Optional[str] = None + self, user_id: str, except_device_id: str | None = None ) -> None: """Delete all of the user's devices @@ -344,7 +343,7 @@ class DeviceHandler: await self.notify_device_update(user_id, device_ids) async def upsert_device( - self, user_id: str, device_id: str, display_name: Optional[str] = None + self, user_id: str, device_id: str, display_name: str | None = None ) -> bool: """Create or update a device @@ -425,9 +424,7 @@ class DeviceHandler: log_kv(device_map) return devices - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[tuple[str, JsonDict]]: + async def get_dehydrated_device(self, user_id: str) -> tuple[str, JsonDict] | None: """Retrieve the information for a dehydrated device. Args: @@ -441,10 +438,10 @@ class DeviceHandler: async def store_dehydrated_device( self, user_id: str, - device_id: Optional[str], + device_id: str | None, device_data: JsonDict, - initial_device_display_name: Optional[str] = None, - keys_for_device: Optional[JsonDict] = None, + initial_device_display_name: str | None = None, + keys_for_device: JsonDict | None = None, ) -> str: """Store a dehydrated device for a user, optionally storing the keys associated with it as well. If the user had a previous dehydrated device, it is removed. @@ -563,7 +560,7 @@ class DeviceHandler: user_id: str, room_ids: StrCollection, from_token: StreamToken, - now_token: Optional[StreamToken] = None, + now_token: StreamToken | None = None, ) -> set[str]: """Get the set of users whose devices have changed who share a room with the given user. @@ -677,7 +674,7 @@ class DeviceHandler: memberships_to_fetch.add(delta.prev_event_id) # Fetch all the memberships for the membership events - event_id_to_memberships: Mapping[str, Optional[EventIdMembership]] = {} + event_id_to_memberships: Mapping[str, EventIdMembership | None] = {} if memberships_to_fetch: event_id_to_memberships = await self.store.get_membership_from_event_ids( memberships_to_fetch @@ -834,7 +831,7 @@ class DeviceHandler: # Check if the application services have any results. if self._query_appservices_for_keys: # Query the appservice for all devices for this user. - query: dict[str, Optional[list[str]]] = {user_id: None} + query: dict[str, list[str] | None] = {user_id: None} # Query the appservices for any keys. appservice_results = await self._appservice_handler.query_keys(query) @@ -923,7 +920,7 @@ class DeviceHandler: async def _delete_device_messages( self, task: ScheduledTask, - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: """Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`.""" assert task.params is not None user_id = task.params["user_id"] @@ -1335,7 +1332,7 @@ class DeviceListWorkerUpdater: async def multi_user_device_resync( self, user_ids: list[str], - ) -> dict[str, Optional[JsonMapping]]: + ) -> dict[str, JsonMapping | None]: """ Like `user_device_resync` but operates on multiple users **from the same origin** at once. @@ -1359,8 +1356,8 @@ class DeviceListWorkerUpdater: async def process_cross_signing_key_update( self, user_id: str, - master_key: Optional[JsonDict], - self_signing_key: Optional[JsonDict], + master_key: JsonDict | None, + self_signing_key: JsonDict | None, ) -> list[str]: """Process the given new master and self-signing key for the given remote user. @@ -1699,7 +1696,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): async def multi_user_device_resync( self, user_ids: list[str], mark_failed_as_stale: bool = True - ) -> dict[str, Optional[JsonMapping]]: + ) -> dict[str, JsonMapping | None]: """ Like `user_device_resync` but operates on multiple users **from the same origin** at once. @@ -1735,7 +1732,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): async def _user_device_resync_returning_failed( self, user_id: str - ) -> tuple[Optional[JsonMapping], bool]: + ) -> tuple[JsonMapping | None, bool]: """Fetches all devices for a user and updates the device cache with them. Args: diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 4dcdcc42fe..0ef14b31da 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -21,7 +21,7 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes from synapse.api.errors import Codes, SynapseError @@ -315,7 +315,7 @@ class DeviceMessageHandler: self, requester: Requester, device_id: str, - since_token: Optional[str], + since_token: str | None, limit: int, ) -> JsonDict: """Fetches up to `limit` events sent to `device_id` starting from `since_token` diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 865c32d19e..03b23fe0be 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -21,7 +21,7 @@ import logging import string -from typing import TYPE_CHECKING, Iterable, Literal, Optional, Sequence +from typing import TYPE_CHECKING, Iterable, Literal, Sequence from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( @@ -73,8 +73,8 @@ class DirectoryHandler: self, room_alias: RoomAlias, room_id: str, - servers: Optional[Iterable[str]] = None, - creator: Optional[str] = None, + servers: Iterable[str] | None = None, + creator: str | None = None, ) -> None: # general association creation for both human users and app services @@ -108,7 +108,7 @@ class DirectoryHandler: requester: Requester, room_alias: RoomAlias, room_id: str, - servers: Optional[list[str]] = None, + servers: list[str] | None = None, check_membership: bool = True, ) -> None: """Attempt to create a new alias @@ -252,7 +252,7 @@ class DirectoryHandler: ) await self._delete_association(room_alias) - async def _delete_association(self, room_alias: RoomAlias) -> Optional[str]: + async def _delete_association(self, room_alias: RoomAlias) -> str | None: if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") @@ -263,16 +263,16 @@ class DirectoryHandler: async def get_association(self, room_alias: RoomAlias) -> JsonDict: room_id = None if self.hs.is_mine(room_alias): - result: Optional[ - RoomAliasMapping - ] = await self.get_association_from_room_alias(room_alias) + result: ( + RoomAliasMapping | None + ) = await self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id servers = result.servers else: try: - fed_result: Optional[JsonDict] = await self.federation.make_query( + fed_result: JsonDict | None = await self.federation.make_query( destination=room_alias.domain, query_type="directory", args={"room_alias": room_alias.to_string()}, @@ -387,7 +387,7 @@ class DirectoryHandler: async def get_association_from_room_alias( self, room_alias: RoomAlias - ) -> Optional[RoomAliasMapping]: + ) -> RoomAliasMapping | None: result = await self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists @@ -395,7 +395,7 @@ class DirectoryHandler: result = await as_handler.query_room_alias_exists(room_alias) return result - def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None) -> bool: + def can_modify_alias(self, alias: RoomAlias, user_id: str | None = None) -> bool: # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 85a150b71a..41d27d47da 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -20,7 +20,7 @@ # # import logging -from typing import TYPE_CHECKING, Iterable, Mapping, Optional +from typing import TYPE_CHECKING, Iterable, Mapping import attr from canonicaljson import encode_canonical_json @@ -132,7 +132,7 @@ class E2eKeysHandler: query_body: JsonDict, timeout: int, from_user_id: str, - from_device_id: Optional[str], + from_device_id: str | None, ) -> JsonDict: """Handle a device key query from a client @@ -479,7 +479,7 @@ class E2eKeysHandler: @cancellable async def get_cross_signing_keys_from_cache( - self, query: Iterable[str], from_user_id: Optional[str] + self, query: Iterable[str], from_user_id: str | None ) -> dict[str, dict[str, JsonMapping]]: """Get cross-signing keys for users from the database @@ -527,7 +527,7 @@ class E2eKeysHandler: @cancellable async def query_local_devices( self, - query: Mapping[str, Optional[list[str]]], + query: Mapping[str, list[str] | None], include_displaynames: bool = True, ) -> dict[str, dict[str, dict]]: """Get E2E device keys for local users @@ -542,7 +542,7 @@ class E2eKeysHandler: A map from user_id -> device_id -> device details """ set_tag("local_query", str(query)) - local_query: list[tuple[str, Optional[str]]] = [] + local_query: list[tuple[str, str | None]] = [] result_dict: dict[str, dict[str, dict]] = {} for user_id, device_ids in query.items(): @@ -594,7 +594,7 @@ class E2eKeysHandler: return result_dict async def on_federation_query_client_keys( - self, query_body: dict[str, dict[str, Optional[list[str]]]] + self, query_body: dict[str, dict[str, list[str] | None]] ) -> JsonDict: """Handle a device key query from a federated server: @@ -614,7 +614,7 @@ class E2eKeysHandler: - self_signing_key: An optional dictionary of user ID -> self-signing key info. """ - device_keys_query: dict[str, Optional[list[str]]] = query_body.get( + device_keys_query: dict[str, list[str] | None] = query_body.get( "device_keys", {} ) if any( @@ -737,7 +737,7 @@ class E2eKeysHandler: self, query: dict[str, dict[str, dict[str, int]]], user: UserID, - timeout: Optional[int], + timeout: int | None, always_include_fallback_keys: bool, ) -> JsonDict: """ @@ -1395,7 +1395,7 @@ class E2eKeysHandler: return signature_list, failures async def _get_e2e_cross_signing_verify_key( - self, user_id: str, key_type: str, from_user_id: Optional[str] = None + self, user_id: str, key_type: str, from_user_id: str | None = None ) -> tuple[JsonMapping, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. @@ -1451,7 +1451,7 @@ class E2eKeysHandler: self, user: UserID, desired_key_type: str, - ) -> Optional[tuple[JsonMapping, str, VerifyKey]]: + ) -> tuple[JsonMapping, str, VerifyKey] | None: """Queries cross-signing keys for a remote user and saves them to the database Only the key specified by `key_type` will be returned, while all retrieved keys @@ -1599,7 +1599,7 @@ class E2eKeysHandler: async def _delete_old_one_time_keys_task( self, task: ScheduledTask - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: """Scheduler task to delete old one time keys. Until Synapse 1.119, Synapse used to issue one-time-keys in a random order, leading to the possibility @@ -1638,7 +1638,7 @@ class E2eKeysHandler: def _check_cross_signing_key( - key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None + key: JsonDict, user_id: str, key_type: str, signing_key: VerifyKey | None = None ) -> None: """Check a cross-signing key uploaded by a user. Performs some basic sanity checking, and ensures that it is signed, if a signature is required. diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 094b4bc27c..017fbcf8b3 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Literal, Optional, cast +from typing import TYPE_CHECKING, Literal, cast from synapse.api.errors import ( Codes, @@ -63,8 +63,8 @@ class E2eRoomKeysHandler: self, user_id: str, version: str, - room_id: Optional[str] = None, - session_id: Optional[str] = None, + room_id: str | None = None, + session_id: str | None = None, ) -> dict[ Literal["rooms"], dict[str, dict[Literal["sessions"], dict[str, RoomKey]]] ]: @@ -109,8 +109,8 @@ class E2eRoomKeysHandler: self, user_id: str, version: str, - room_id: Optional[str] = None, - session_id: Optional[str] = None, + room_id: str | None = None, + session_id: str | None = None, ) -> JsonDict: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. @@ -299,7 +299,7 @@ class E2eRoomKeysHandler: @staticmethod def _should_replace_room_key( - current_room_key: Optional[RoomKey], room_key: RoomKey + current_room_key: RoomKey | None, room_key: RoomKey ) -> bool: """ Determine whether to replace a given current_room_key (if any) @@ -360,7 +360,7 @@ class E2eRoomKeysHandler: return new_version async def get_version_info( - self, user_id: str, version: Optional[str] = None + self, user_id: str, version: str | None = None ) -> JsonDict: """Get the info about a given version of the user's backup @@ -394,7 +394,7 @@ class E2eRoomKeysHandler: return res @trace - async def delete_version(self, user_id: str, version: Optional[str] = None) -> None: + async def delete_version(self, user_id: str, version: str | None = None) -> None: """Deletes a given version of the user's e2e_room_keys backup Args: diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index b2caca8ce7..4f2657bba8 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -19,7 +19,7 @@ # # import logging -from typing import TYPE_CHECKING, Mapping, Optional, Union +from typing import TYPE_CHECKING, Mapping from synapse import event_auth from synapse.api.constants import ( @@ -61,7 +61,7 @@ class EventAuthHandler: async def check_auth_rules_from_context( self, event: EventBase, - batched_auth_events: Optional[Mapping[str, EventBase]] = None, + batched_auth_events: Mapping[str, EventBase] | None = None, ) -> None: """Check an event passes the auth rules at its own auth events Args: @@ -89,7 +89,7 @@ class EventAuthHandler: def compute_auth_events( self, - event: Union[EventBase, EventBuilder], + event: EventBase | EventBuilder, current_state_ids: StateMap[str], for_verification: bool = False, ) -> list[str]: @@ -236,7 +236,7 @@ class EventAuthHandler: state_ids: StateMap[str], room_version: RoomVersion, user_id: str, - prev_membership: Optional[str], + prev_membership: str | None, ) -> None: """ Check whether a user can join a room without an invite due to restricted join rules. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 9522d5a696..ae17639206 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -21,7 +21,7 @@ import logging import random -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.errors import AuthError, SynapseError @@ -58,7 +58,7 @@ class EventStreamHandler: timeout: int = 0, as_client_event: bool = True, affect_presence: bool = True, - room_id: Optional[str] = None, + room_id: str | None = None, ) -> JsonDict: """Fetches the events stream for a given user.""" @@ -152,10 +152,10 @@ class EventHandler: async def get_event( self, user: UserID, - room_id: Optional[str], + room_id: str | None, event_id: str, show_redacted: bool = False, - ) -> Optional[EventBase]: + ) -> EventBase | None: """Retrieve a single specified event. Args: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3eb1d166f8..1bba3fc758 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -31,8 +31,6 @@ from typing import ( TYPE_CHECKING, AbstractSet, Iterable, - Optional, - Union, ) import attr @@ -169,7 +167,7 @@ class FederationHandler: # A dictionary mapping room IDs to (initial destination, other destinations) # tuples. self._partial_state_syncs_maybe_needing_restart: dict[ - str, tuple[Optional[str], AbstractSet[str]] + str, tuple[str | None, AbstractSet[str]] ] = {} # A lock guarding the partial state flag for rooms. # When the lock is held for a given room, no other concurrent code may @@ -232,7 +230,7 @@ class FederationHandler: current_depth: int, limit: int, *, - processing_start_time: Optional[int], + processing_start_time: int | None, ) -> bool: """ Checks whether the `current_depth` is at or approaching any backfill @@ -1174,7 +1172,7 @@ class FederationHandler: user_id: str, membership: str, content: JsonDict, - params: Optional[dict[str, Union[str, Iterable[str]]]] = None, + params: dict[str, str | Iterable[str]] | None = None, ) -> tuple[str, EventBase, RoomVersion]: ( origin, @@ -1371,9 +1369,7 @@ class FederationHandler: return events - async def get_persisted_pdu( - self, origin: str, event_id: str - ) -> Optional[EventBase]: + async def get_persisted_pdu(self, origin: str, event_id: str) -> EventBase | None: """Get an event from the database for the given server. Args: @@ -1670,7 +1666,7 @@ class FederationHandler: logger.debug("Checking auth on event %r", event.content) - last_exception: Optional[Exception] = None + last_exception: Exception | None = None # for each public key in the 3pid invite event for public_key_object in event_auth.get_public_keys(invite_event): @@ -1755,7 +1751,7 @@ class FederationHandler: async def get_room_complexity( self, remote_room_hosts: list[str], room_id: str - ) -> Optional[dict]: + ) -> dict | None: """ Fetch the complexity of a remote room over federation. @@ -1793,7 +1789,7 @@ class FederationHandler: def _start_partial_state_room_sync( self, - initial_destination: Optional[str], + initial_destination: str | None, other_destinations: AbstractSet[str], room_id: str, ) -> None: @@ -1876,7 +1872,7 @@ class FederationHandler: async def _sync_partial_state_room( self, - initial_destination: Optional[str], + initial_destination: str | None, other_destinations: AbstractSet[str], room_id: str, ) -> None: @@ -2018,7 +2014,7 @@ class FederationHandler: def _prioritise_destinations_for_partial_state_resync( - initial_destination: Optional[str], + initial_destination: str | None, other_destinations: AbstractSet[str], room_id: str, ) -> StrCollection: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 32b603e947..01e98f60ad 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -28,7 +28,6 @@ from typing import ( Collection, Container, Iterable, - Optional, Sequence, ) @@ -1818,7 +1817,7 @@ class FederationEventHandler: @trace async def _check_event_auth( - self, origin: Optional[str], event: EventBase, context: EventContext + self, origin: str | None, event: EventBase, context: EventContext ) -> None: """ Checks whether an event should be rejected (for failing auth checks). @@ -2101,7 +2100,7 @@ class FederationEventHandler: event.internal_metadata.soft_failed = True async def _load_or_fetch_auth_events_for_event( - self, destination: Optional[str], event: EventBase + self, destination: str | None, event: EventBase ) -> Collection[EventBase]: """Fetch this event's auth_events, from database or remote diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 0f507b3317..1596c55570 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -24,7 +24,7 @@ import logging import urllib.parse -from typing import TYPE_CHECKING, Awaitable, Callable, Optional +from typing import TYPE_CHECKING, Awaitable, Callable import attr @@ -106,7 +106,7 @@ class IdentityHandler: async def threepid_from_creds( self, id_server: str, creds: dict[str, str] - ) -> Optional[JsonDict]: + ) -> JsonDict | None: """ Retrieve and validate a threepid identifier from a "credentials" dictionary against a given identity server @@ -227,7 +227,7 @@ class IdentityHandler: return data async def try_unbind_threepid( - self, mxid: str, medium: str, address: str, id_server: Optional[str] + self, mxid: str, medium: str, address: str, id_server: str | None ) -> bool: """Attempt to remove a 3PID from one or more identity servers. @@ -338,7 +338,7 @@ class IdentityHandler: client_secret: str, send_attempt: int, send_email_func: Callable[[str, str, str, str], Awaitable], - next_link: Optional[str] = None, + next_link: str | None = None, ) -> str: """Send a threepid validation email for password reset or registration purposes @@ -426,7 +426,7 @@ class IdentityHandler: phone_number: str, client_secret: str, send_attempt: int, - next_link: Optional[str] = None, + next_link: str | None = None, ) -> JsonDict: """ Request an external server send an SMS message on our behalf for the purposes of @@ -473,7 +473,7 @@ class IdentityHandler: async def validate_threepid_session( self, client_secret: str, sid: str - ) -> Optional[JsonDict]: + ) -> JsonDict | None: """Validates a threepid session with only the client secret and session ID Tries validating against any configured account_threepid_delegates as well as locally. @@ -541,7 +541,7 @@ class IdentityHandler: async def lookup_3pid( self, id_server: str, medium: str, address: str, id_access_token: str - ) -> Optional[str]: + ) -> str | None: """Looks up a 3pid in the passed identity server. Args: @@ -567,7 +567,7 @@ class IdentityHandler: async def _lookup_3pid_v2( self, id_server: str, id_access_token: str, medium: str, address: str - ) -> Optional[str]: + ) -> str | None: """Looks up a 3pid in the passed identity server using v2 lookup. Args: @@ -689,7 +689,7 @@ class IdentityHandler: room_avatar_url: str, room_join_rules: str, room_name: str, - room_type: Optional[str], + room_type: str | None, inviter_display_name: str, inviter_avatar_url: str, id_access_token: str, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 1c6f8bf53b..611c4fa7b3 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.constants import ( AccountDataTypes, @@ -71,8 +71,8 @@ class InitialSyncHandler: self.snapshot_cache: ResponseCache[ tuple[ str, - Optional[StreamToken], - Optional[StreamToken], + StreamToken | None, + StreamToken | None, Direction, int, bool, diff --git a/synapse/handlers/jwt.py b/synapse/handlers/jwt.py index f1715f6495..67b2b7c31d 100644 --- a/synapse/handlers/jwt.py +++ b/synapse/handlers/jwt.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from authlib.jose import JsonWebToken, JWTClaims from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError @@ -41,7 +41,7 @@ class JwtHandler: self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences - def validate_login(self, login_submission: JsonDict) -> tuple[str, Optional[str]]: + def validate_login(self, login_submission: JsonDict) -> tuple[str, str | None]: """ Authenticates the user for the /login API diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 2ad1dbe73f..7679303a36 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,7 +22,7 @@ import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Any, Mapping, Sequence from canonicaljson import encode_canonical_json @@ -110,7 +110,7 @@ class MessageHandler: # The scheduled call to self._expire_event. None if no call is currently # scheduled. - self._scheduled_expiry: Optional[IDelayedCall] = None + self._scheduled_expiry: IDelayedCall | None = None if not hs.config.worker.worker_app: self.hs.run_as_background_process( @@ -123,7 +123,7 @@ class MessageHandler: room_id: str, event_type: str, state_key: str, - ) -> Optional[EventBase]: + ) -> EventBase | None: """Get data from a room. Args: @@ -178,8 +178,8 @@ class MessageHandler: self, requester: Requester, room_id: str, - state_filter: Optional[StateFilter] = None, - at_token: Optional[StreamToken] = None, + state_filter: StateFilter | None = None, + at_token: StreamToken | None = None, ) -> list[dict]: """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has @@ -563,7 +563,7 @@ class EventCreationHandler: # Stores the state groups we've recently added to the joined hosts # external cache. Note that the timeout must be significantly less than # the TTL on the external cache. - self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None + self._external_cache_joined_hosts_updates: ExpiringCache | None = None if self._external_cache.is_enabled(): self._external_cache_joined_hosts_updates = ExpiringCache( cache_name="_external_cache_joined_hosts_updates", @@ -577,16 +577,16 @@ class EventCreationHandler: self, requester: Requester, event_dict: dict, - txn_id: Optional[str] = None, - prev_event_ids: Optional[list[str]] = None, - auth_event_ids: Optional[list[str]] = None, - state_event_ids: Optional[list[str]] = None, + txn_id: str | None = None, + prev_event_ids: list[str] | None = None, + auth_event_ids: list[str] | None = None, + state_event_ids: list[str] | None = None, require_consent: bool = True, outlier: bool = False, - depth: Optional[int] = None, - state_map: Optional[StateMap[str]] = None, + depth: int | None = None, + state_map: StateMap[str] | None = None, for_batch: bool = False, - current_state_group: Optional[int] = None, + current_state_group: int | None = None, ) -> tuple[EventBase, UnpersistedEventContextBase]: """ Given a dict from a client, create a new event. If bool for_batch is true, will @@ -865,7 +865,7 @@ class EventCreationHandler: async def deduplicate_state_event( self, event: EventBase, context: EventContext - ) -> Optional[EventBase]: + ) -> EventBase | None: """ Checks whether event is in the latest resolved state in context. @@ -903,7 +903,7 @@ class EventCreationHandler: requester: Requester, txn_id: str, room_id: str, - ) -> Optional[str]: + ) -> str | None: """For the given transaction ID and room ID, check if there is a matching event ID. Args: @@ -937,7 +937,7 @@ class EventCreationHandler: requester: Requester, txn_id: str, room_id: str, - ) -> Optional[EventBase]: + ) -> EventBase | None: """For the given transaction ID and room ID, check if there is a matching event. If so, fetch it and return it. @@ -961,13 +961,13 @@ class EventCreationHandler: self, requester: Requester, event_dict: dict, - prev_event_ids: Optional[list[str]] = None, - state_event_ids: Optional[list[str]] = None, + prev_event_ids: list[str] | None = None, + state_event_ids: list[str] | None = None, ratelimit: bool = True, - txn_id: Optional[str] = None, + txn_id: str | None = None, ignore_shadow_ban: bool = False, outlier: bool = False, - depth: Optional[int] = None, + depth: int | None = None, ) -> tuple[EventBase, int]: """ Creates an event, then sends it. @@ -1098,13 +1098,13 @@ class EventCreationHandler: self, requester: Requester, event_dict: dict, - prev_event_ids: Optional[list[str]] = None, - state_event_ids: Optional[list[str]] = None, + prev_event_ids: list[str] | None = None, + state_event_ids: list[str] | None = None, ratelimit: bool = True, - txn_id: Optional[str] = None, + txn_id: str | None = None, ignore_shadow_ban: bool = False, outlier: bool = False, - depth: Optional[int] = None, + depth: int | None = None, ) -> tuple[EventBase, int]: room_id = event_dict["room_id"] @@ -1219,14 +1219,14 @@ class EventCreationHandler: async def create_new_client_event( self, builder: EventBuilder, - requester: Optional[Requester] = None, - prev_event_ids: Optional[list[str]] = None, - auth_event_ids: Optional[list[str]] = None, - state_event_ids: Optional[list[str]] = None, - depth: Optional[int] = None, - state_map: Optional[StateMap[str]] = None, + requester: Requester | None = None, + prev_event_ids: list[str] | None = None, + auth_event_ids: list[str] | None = None, + state_event_ids: list[str] | None = None, + depth: int | None = None, + state_map: StateMap[str] | None = None, for_batch: bool = False, - current_state_group: Optional[int] = None, + current_state_group: int | None = None, ) -> tuple[EventBase, UnpersistedEventContextBase]: """Create a new event for a local client. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -1473,7 +1473,7 @@ class EventCreationHandler: requester: Requester, events_and_context: list[EventPersistencePair], ratelimit: bool = True, - extra_users: Optional[list[UserID]] = None, + extra_users: list[UserID] | None = None, ignore_shadow_ban: bool = False, ) -> EventBase: """Processes new events. Please note that if batch persisting events, an error in @@ -1592,7 +1592,7 @@ class EventCreationHandler: self, requester: Requester, room_id: str, - prev_event_id: Optional[str], + prev_event_id: str | None, event_dicts: Sequence[JsonDict], ratelimit: bool = True, ignore_shadow_ban: bool = False, @@ -1685,7 +1685,7 @@ class EventCreationHandler: requester: Requester, events_and_context: list[EventPersistencePair], ratelimit: bool = True, - extra_users: Optional[list[UserID]] = None, + extra_users: list[UserID] | None = None, ) -> EventBase: """Actually persists new events. Should only be called by `handle_new_client_event`, and see its docstring for documentation of @@ -1877,7 +1877,7 @@ class EventCreationHandler: requester: Requester, events_and_context: list[EventPersistencePair], ratelimit: bool = True, - extra_users: Optional[list[UserID]] = None, + extra_users: list[UserID] | None = None, ) -> EventBase: """Called when we have fully built the events, have already calculated the push actions for the events, and checked auth. @@ -2132,7 +2132,7 @@ class EventCreationHandler: return persisted_events[-1] async def is_admin_redaction( - self, event_type: str, sender: str, redacts: Optional[str] + self, event_type: str, sender: str, redacts: str | None ) -> bool: """Return whether the event is a redaction made by an admin, and thus should use a different ratelimiter. @@ -2174,7 +2174,7 @@ class EventCreationHandler: logger.info("maybe_kick_guest_users %r", current_state) await self.hs.get_room_member_handler().kick_guest_users(current_state) - async def _bump_active_time(self, user: UserID, device_id: Optional[str]) -> None: + async def _bump_active_time(self, user: UserID, device_id: str | None) -> None: try: presence = self.hs.get_presence_handler() await presence.bump_presence_active_time(user, device_id) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index f140912b2a..429a739380 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -27,10 +27,8 @@ from typing import ( TYPE_CHECKING, Any, Generic, - Optional, TypedDict, TypeVar, - Union, ) from urllib.parse import urlencode, urlparse @@ -102,10 +100,10 @@ _SESSION_COOKIES = [ class Token(TypedDict): access_token: str token_type: str - id_token: Optional[str] - refresh_token: Optional[str] + id_token: str | None + refresh_token: str | None expires_in: int - scope: Optional[str] + scope: str | None #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but @@ -206,7 +204,7 @@ class OidcHandler: # are two. for cookie_name, _ in _SESSION_COOKIES: - session: Optional[bytes] = request.getCookie(cookie_name) + session: bytes | None = request.getCookie(cookie_name) if session is not None: break else: @@ -335,7 +333,7 @@ class OidcHandler: # Now that we know the audience and the issuer, we can figure out from # what provider it is coming from - oidc_provider: Optional[OidcProvider] = None + oidc_provider: OidcProvider | None = None for provider in self._providers.values(): if provider.issuer == issuer and provider.client_id in audience: oidc_provider = provider @@ -351,7 +349,7 @@ class OidcHandler: class OidcError(Exception): """Used to catch errors when calling the token_endpoint""" - def __init__(self, error: str, error_description: Optional[str] = None): + def __init__(self, error: str, error_description: str | None = None): self.error = error self.error_description = error_description @@ -398,7 +396,7 @@ class OidcProvider: self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method - client_secret: Optional[Union[str, JwtClientSecret]] = None + client_secret: str | JwtClientSecret | None = None if provider.client_secret: client_secret = provider.client_secret elif provider.client_secret_jwt_key: @@ -904,8 +902,8 @@ class OidcProvider: alg_values: list[str], token: str, claims_cls: type[C], - claims_options: Optional[dict] = None, - claims_params: Optional[dict] = None, + claims_options: dict | None = None, + claims_params: dict | None = None, ) -> C: """Decode and validate a JWT, re-fetching the JWKS as needed. @@ -1005,8 +1003,8 @@ class OidcProvider: async def handle_redirect_request( self, request: SynapseRequest, - client_redirect_url: Optional[bytes], - ui_auth_session_id: Optional[str] = None, + client_redirect_url: bytes | None, + ui_auth_session_id: str | None = None, ) -> str: """Handle an incoming request to /login/sso/redirect @@ -1235,7 +1233,7 @@ class OidcProvider: token: Token, request: SynapseRequest, client_redirect_url: str, - sid: Optional[str], + sid: str | None, ) -> None: """Given a UserInfo response, complete the login flow @@ -1300,7 +1298,7 @@ class OidcProvider: return UserAttributes(**attributes) - async def grandfather_existing_users() -> Optional[str]: + async def grandfather_existing_users() -> str | None: if self._allow_existing_users: # If allowing existing users we want to generate a single localpart # and attempt to match it. @@ -1444,8 +1442,8 @@ class OidcProvider: # If the `sub` claim was included in the logout token, we check that it matches # that it matches the right user. We can have cases where the `sub` claim is not # the ID saved in database, so we let admins disable this check in config. - sub: Optional[str] = claims.get("sub") - expected_user_id: Optional[str] = None + sub: str | None = claims.get("sub") + expected_user_id: str | None = None if sub is not None and not self._config.backchannel_logout_ignore_sub: expected_user_id = await self._store.get_user_by_external_id( self.idp_id, sub @@ -1473,7 +1471,7 @@ class LogoutToken(JWTClaims): # type: ignore[misc] REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"] - def validate(self, now: Optional[int] = None, leeway: int = 0) -> None: + def validate(self, now: int | None = None, leeway: int = 0) -> None: """Validate everything in claims payload.""" super().validate(now, leeway) self.validate_sid() @@ -1584,10 +1582,10 @@ class JwtClientSecret: class UserAttributeDict(TypedDict): - localpart: Optional[str] + localpart: str | None confirm_localpart: bool - display_name: Optional[str] - picture: Optional[str] # may be omitted by older `OidcMappingProviders` + display_name: str | None + picture: str | None # may be omitted by older `OidcMappingProviders` emails: list[str] @@ -1674,9 +1672,9 @@ env.filters.update( class JinjaOidcMappingConfig: subject_template: Template picture_template: Template - localpart_template: Optional[Template] - display_name_template: Optional[Template] - email_template: Optional[Template] + localpart_template: Template | None + display_name_template: Template | None + email_template: Template | None extra_attributes: dict[str, Template] confirm_localpart: bool = False @@ -1710,7 +1708,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): subject_template = parse_template_config_with_claim("subject", "sub") picture_template = parse_template_config_with_claim("picture", "picture") - def parse_template_config(option_name: str) -> Optional[Template]: + def parse_template_config(option_name: str) -> Template | None: if option_name not in config: return None try: @@ -1768,7 +1766,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): # a usable mxid. localpart += str(failures) if failures else "" - def render_template_field(template: Optional[Template]) -> Optional[str]: + def render_template_field(template: Template | None) -> str | None: if template is None: return None return template.render(user=userinfo).strip() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 7274a512b0..a90ed3193c 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -19,7 +19,7 @@ # # import logging -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast from twisted.python.failure import Failure @@ -132,7 +132,7 @@ class PaginationHandler: ) async def purge_history_for_rooms_in_range( - self, min_ms: Optional[int], max_ms: Optional[int] + self, min_ms: int | None, max_ms: int | None ) -> None: """Purge outdated events from rooms within the given retention range. @@ -279,7 +279,7 @@ class PaginationHandler: async def _purge_history( self, task: ScheduledTask, - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: """ Scheduler action to purge some history of a room. """ @@ -308,7 +308,7 @@ class PaginationHandler: room_id: str, token: str, delete_local_events: bool, - ) -> Optional[str]: + ) -> str | None: """Carry out a history purge on a room. Args: @@ -332,7 +332,7 @@ class PaginationHandler: ) return f.getErrorMessage() - async def get_delete_task(self, delete_id: str) -> Optional[ScheduledTask]: + async def get_delete_task(self, delete_id: str) -> ScheduledTask | None: """Get the current status of an active deleting Args: @@ -342,7 +342,7 @@ class PaginationHandler: return await self._task_scheduler.get_task(delete_id) async def get_delete_tasks_by_room( - self, room_id: str, only_active: Optional[bool] = False + self, room_id: str, only_active: bool | None = False ) -> list[ScheduledTask]: """Get complete, failed or active delete tasks by room @@ -363,7 +363,7 @@ class PaginationHandler: async def _purge_room( self, task: ScheduledTask, - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: """ Scheduler action to purge a room. """ @@ -415,7 +415,7 @@ class PaginationHandler: room_id: str, pagin_config: PaginationConfig, as_client_event: bool = True, - event_filter: Optional[Filter] = None, + event_filter: Filter | None = None, use_admin_priviledge: bool = False, ) -> JsonDict: """Get messages in a room. @@ -691,7 +691,7 @@ class PaginationHandler: async def _shutdown_and_purge_room( self, task: ScheduledTask, - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: """ Scheduler action to shutdown and purge a room. """ @@ -702,7 +702,7 @@ class PaginationHandler: room_id = task.resource_id - async def update_result(result: Optional[JsonMapping]) -> None: + async def update_result(result: JsonMapping | None) -> None: await self._task_scheduler.update_task(task.id, result=result) shutdown_result = ( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index d8150a5857..ca5002cab3 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -88,7 +88,6 @@ from typing import ( ContextManager, Generator, Iterable, - Optional, ) from prometheus_client import Counter @@ -248,7 +247,7 @@ class BasePresenceHandler(abc.ABC): async def user_syncing( self, user_id: str, - device_id: Optional[str], + device_id: str | None, affect_presence: bool, presence_state: str, ) -> ContextManager[None]: @@ -271,7 +270,7 @@ class BasePresenceHandler(abc.ABC): @abc.abstractmethod def get_currently_syncing_users_for_replication( self, - ) -> Iterable[tuple[str, Optional[str]]]: + ) -> Iterable[tuple[str, str | None]]: """Get an iterable of syncing users and devices on this worker, to send to the presence handler This is called when a replication connection is established. It should return @@ -340,7 +339,7 @@ class BasePresenceHandler(abc.ABC): async def set_state( self, target_user: UserID, - device_id: Optional[str], + device_id: str | None, state: JsonDict, force_notify: bool = False, is_sync: bool = False, @@ -360,7 +359,7 @@ class BasePresenceHandler(abc.ABC): @abc.abstractmethod async def bump_presence_active_time( - self, user: UserID, device_id: Optional[str] + self, user: UserID, device_id: str | None ) -> None: """We've seen the user do something that indicates they're interacting with the app. @@ -370,7 +369,7 @@ class BasePresenceHandler(abc.ABC): self, process_id: str, user_id: str, - device_id: Optional[str], + device_id: str | None, is_syncing: bool, sync_time_msec: int, ) -> None: @@ -496,9 +495,9 @@ class _NullContextManager(ContextManager[None]): def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: pass @@ -517,16 +516,14 @@ class WorkerPresenceHandler(BasePresenceHandler): # The number of ongoing syncs on this process, by (user ID, device ID). # Empty if _presence_enabled is false. - self._user_device_to_num_current_syncs: dict[ - tuple[str, Optional[str]], int - ] = {} + self._user_device_to_num_current_syncs: dict[tuple[str, str | None], int] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() # (user_id, device_id) -> last_sync_ms. Lists the devices that have stopped # syncing but we haven't notified the presence writer of that yet - self._user_devices_going_offline: dict[tuple[str, Optional[str]], int] = {} + self._user_devices_going_offline: dict[tuple[str, str | None], int] = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -549,7 +546,7 @@ class WorkerPresenceHandler(BasePresenceHandler): def send_user_sync( self, user_id: str, - device_id: Optional[str], + device_id: str | None, is_syncing: bool, last_sync_ms: int, ) -> None: @@ -558,7 +555,7 @@ class WorkerPresenceHandler(BasePresenceHandler): self.instance_id, user_id, device_id, is_syncing, last_sync_ms ) - def mark_as_coming_online(self, user_id: str, device_id: Optional[str]) -> None: + def mark_as_coming_online(self, user_id: str, device_id: str | None) -> None: """A user has started syncing. Send a UserSync to the presence writer, unless they had recently stopped syncing. """ @@ -568,7 +565,7 @@ class WorkerPresenceHandler(BasePresenceHandler): # were offline self.send_user_sync(user_id, device_id, True, self.clock.time_msec()) - def mark_as_going_offline(self, user_id: str, device_id: Optional[str]) -> None: + def mark_as_going_offline(self, user_id: str, device_id: str | None) -> None: """A user has stopped syncing. We wait before notifying the presence writer as its likely they'll come back soon. This allows us to avoid sending a stopped syncing immediately followed by a started syncing @@ -591,7 +588,7 @@ class WorkerPresenceHandler(BasePresenceHandler): async def user_syncing( self, user_id: str, - device_id: Optional[str], + device_id: str | None, affect_presence: bool, presence_state: str, ) -> ContextManager[None]: @@ -699,7 +696,7 @@ class WorkerPresenceHandler(BasePresenceHandler): def get_currently_syncing_users_for_replication( self, - ) -> Iterable[tuple[str, Optional[str]]]: + ) -> Iterable[tuple[str, str | None]]: return [ user_id_device_id for user_id_device_id, count in self._user_device_to_num_current_syncs.items() @@ -709,7 +706,7 @@ class WorkerPresenceHandler(BasePresenceHandler): async def set_state( self, target_user: UserID, - device_id: Optional[str], + device_id: str | None, state: JsonDict, force_notify: bool = False, is_sync: bool = False, @@ -748,7 +745,7 @@ class WorkerPresenceHandler(BasePresenceHandler): ) async def bump_presence_active_time( - self, user: UserID, device_id: Optional[str] + self, user: UserID, device_id: str | None ) -> None: """We've seen the user do something that indicates they're interacting with the app. @@ -786,7 +783,7 @@ class PresenceHandler(BasePresenceHandler): # The per-device presence state, maps user to devices to per-device presence state. self._user_to_device_to_current_state: dict[ - str, dict[Optional[str], UserDevicePresenceState] + str, dict[str | None, UserDevicePresenceState] ] = {} now = self.clock.time_msec() @@ -838,9 +835,7 @@ class PresenceHandler(BasePresenceHandler): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self._user_device_to_num_current_syncs: dict[ - tuple[str, Optional[str]], int - ] = {} + self._user_device_to_num_current_syncs: dict[tuple[str, str | None], int] = {} # Keeps track of the number of *ongoing* syncs on other processes. # @@ -853,7 +848,7 @@ class PresenceHandler(BasePresenceHandler): # Stored as a dict from process_id to set of (user_id, device_id), and # a dict of process_id to millisecond timestamp last updated. self.external_process_to_current_syncs: dict[ - str, set[tuple[str, Optional[str]]] + str, set[tuple[str, str | None]] ] = {} self.external_process_last_updated_ms: dict[str, int] = {} @@ -1117,7 +1112,7 @@ class PresenceHandler(BasePresenceHandler): return await self._update_states(changes) async def bump_presence_active_time( - self, user: UserID, device_id: Optional[str] + self, user: UserID, device_id: str | None ) -> None: """We've seen the user do something that indicates they're interacting with the app. @@ -1156,7 +1151,7 @@ class PresenceHandler(BasePresenceHandler): async def user_syncing( self, user_id: str, - device_id: Optional[str], + device_id: str | None, affect_presence: bool = True, presence_state: str = PresenceState.ONLINE, ) -> ContextManager[None]: @@ -1216,7 +1211,7 @@ class PresenceHandler(BasePresenceHandler): def get_currently_syncing_users_for_replication( self, - ) -> Iterable[tuple[str, Optional[str]]]: + ) -> Iterable[tuple[str, str | None]]: # since we are the process handling presence, there is nothing to do here. return [] @@ -1224,7 +1219,7 @@ class PresenceHandler(BasePresenceHandler): self, process_id: str, user_id: str, - device_id: Optional[str], + device_id: str | None, is_syncing: bool, sync_time_msec: int, ) -> None: @@ -1388,7 +1383,7 @@ class PresenceHandler(BasePresenceHandler): async def set_state( self, target_user: UserID, - device_id: Optional[str], + device_id: str | None, state: JsonDict, force_notify: bool = False, is_sync: bool = False, @@ -1835,15 +1830,15 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): async def get_new_events( self, user: UserID, - from_key: Optional[int], + from_key: int | None, # Having a default limit doesn't match the EventSource API, but some # callers do not provide it. It is unused in this class. limit: int = 0, - room_ids: Optional[StrCollection] = None, + room_ids: StrCollection | None = None, is_guest: bool = False, - explicit_room_id: Optional[str] = None, + explicit_room_id: str | None = None, include_offline: bool = True, - service: Optional[ApplicationService] = None, + service: ApplicationService | None = None, ) -> tuple[list[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. @@ -1995,7 +1990,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self, user_id: str, include_offline: bool, - from_key: Optional[int] = None, + from_key: int | None = None, ) -> list[UserPresenceState]: """ Computes the presence updates a user should receive. @@ -2076,8 +2071,8 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): def handle_timeouts( user_states: list[UserPresenceState], is_mine_fn: Callable[[str], bool], - syncing_user_devices: AbstractSet[tuple[str, Optional[str]]], - user_to_devices: dict[str, dict[Optional[str], UserDevicePresenceState]], + syncing_user_devices: AbstractSet[tuple[str, str | None]], + user_to_devices: dict[str, dict[str | None, UserDevicePresenceState]], now: int, ) -> list[UserPresenceState]: """Checks the presence of users that have timed out and updates as @@ -2115,10 +2110,10 @@ def handle_timeouts( def handle_timeout( state: UserPresenceState, is_mine: bool, - syncing_device_ids: AbstractSet[tuple[str, Optional[str]]], - user_devices: dict[Optional[str], UserDevicePresenceState], + syncing_device_ids: AbstractSet[tuple[str, str | None]], + user_devices: dict[str | None, UserDevicePresenceState], now: int, -) -> Optional[UserPresenceState]: +) -> UserPresenceState | None: """Checks the presence of the user to see if any of the timers have elapsed Args: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 240a235a0e..59904cd995 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -20,7 +20,7 @@ # import logging import random -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from synapse.api.constants import ProfileFields from synapse.api.errors import ( @@ -68,8 +68,8 @@ class ProfileHandler: self.user_directory_handler = hs.get_user_directory_handler() self.request_ratelimiter = hs.get_request_ratelimiter() - self.max_avatar_size: Optional[int] = hs.config.server.max_avatar_size - self.allowed_avatar_mimetypes: Optional[list[str]] = ( + self.max_avatar_size: int | None = hs.config.server.max_avatar_size + self.allowed_avatar_mimetypes: list[str] | None = ( hs.config.server.allowed_avatar_mimetypes ) @@ -133,7 +133,7 @@ class ProfileHandler: raise SynapseError(502, "Failed to fetch profile") raise e.to_synapse_error() - async def get_displayname(self, target_user: UserID) -> Optional[str]: + async def get_displayname(self, target_user: UserID) -> str | None: """ Fetch a user's display name from their profile. @@ -211,7 +211,7 @@ class ProfileHandler: 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) - displayname_to_set: Optional[str] = new_displayname.strip() + displayname_to_set: str | None = new_displayname.strip() if new_displayname == "": displayname_to_set = None @@ -238,7 +238,7 @@ class ProfileHandler: if propagate: await self._update_join_states(requester, target_user) - async def get_avatar_url(self, target_user: UserID) -> Optional[str]: + async def get_avatar_url(self, target_user: UserID) -> str | None: """ Fetch a user's avatar URL from their profile. @@ -316,7 +316,7 @@ class ProfileHandler: if not await self.check_avatar_size_and_mime_type(new_avatar_url): raise SynapseError(403, "This avatar is not allowed", Codes.FORBIDDEN) - avatar_url_to_set: Optional[str] = new_avatar_url + avatar_url_to_set: str | None = new_avatar_url if new_avatar_url == "": avatar_url_to_set = None @@ -367,9 +367,9 @@ class ProfileHandler: server_name = host if self._is_mine_server_name(server_name): - media_info: Optional[ - Union[LocalMedia, RemoteMedia] - ] = await self.store.get_local_media(media_id) + media_info: ( + LocalMedia | RemoteMedia | None + ) = await self.store.get_local_media(media_id) else: media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -606,7 +606,7 @@ class ProfileHandler: ) async def check_profile_query_allowed( - self, target_user: UserID, requester: Optional[UserID] = None + self, target_user: UserID, requester: UserID | None = None ) -> None: """Checks whether a profile query is allowed. If the 'require_auth_for_profile_requests' config flag is set to True and a diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py index 643fa72f3f..746c712bac 100644 --- a/synapse/handlers/push_rules.py +++ b/synapse/handlers/push_rules.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import attr @@ -40,7 +40,7 @@ class RuleSpec: scope: str template: str rule_id: str - attr: Optional[str] + attr: str | None class PushRulesHandler: @@ -51,7 +51,7 @@ class PushRulesHandler: self._main_store = hs.get_datastores().main async def set_rule_attr( - self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] + self, user_id: str, spec: RuleSpec, val: bool | JsonDict ) -> None: """Set an attribute (enabled or actions) on an existing push rule. @@ -137,7 +137,7 @@ class PushRulesHandler: return rules -def check_actions(actions: list[Union[str, JsonDict]]) -> None: +def check_actions(actions: list[str | JsonDict]) -> None: """Check if the given actions are spec compliant. Args: diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index ad41113b5b..f6383baf0b 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -19,7 +19,7 @@ # # import logging -from typing import TYPE_CHECKING, Iterable, Optional, Sequence +from typing import TYPE_CHECKING, Iterable, Sequence from synapse.api.constants import EduTypes, ReceiptTypes from synapse.appservice import ApplicationService @@ -180,7 +180,7 @@ class ReceiptsHandler: receipt_type: str, user_id: UserID, event_id: str, - thread_id: Optional[str], + thread_id: str | None, ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. @@ -285,8 +285,8 @@ class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]): limit: int, room_ids: Iterable[str], is_guest: bool, - explicit_room_id: Optional[str] = None, - to_key: Optional[MultiWriterStreamToken] = None, + explicit_room_id: str | None = None, + to_key: MultiWriterStreamToken | None = None, ) -> tuple[list[JsonMapping], MultiWriterStreamToken]: """ Find read receipts for given rooms (> `from_token` and <= `to_token`) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8b620a91bc..139c14dcf4 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -26,7 +26,6 @@ import logging from typing import ( TYPE_CHECKING, Iterable, - Optional, TypedDict, ) @@ -106,8 +105,8 @@ def init_counters_for_auth_provider(auth_provider_id: str, server_name: str) -> class LoginDict(TypedDict): device_id: str access_token: str - valid_until_ms: Optional[int] - refresh_token: Optional[str] + valid_until_ms: int | None + refresh_token: str | None class RegistrationHandler: @@ -160,8 +159,8 @@ class RegistrationHandler: async def check_username( self, localpart: str, - guest_access_token: Optional[str] = None, - assigned_user_id: Optional[str] = None, + guest_access_token: str | None = None, + assigned_user_id: str | None = None, inhibit_user_in_use_error: bool = False, ) -> None: if types.contains_invalid_mxid_characters(localpart): @@ -228,19 +227,19 @@ class RegistrationHandler: async def register_user( self, - localpart: Optional[str] = None, - password_hash: Optional[str] = None, - guest_access_token: Optional[str] = None, + localpart: str | None = None, + password_hash: str | None = None, + guest_access_token: str | None = None, make_guest: bool = False, admin: bool = False, - threepid: Optional[dict] = None, - user_type: Optional[str] = None, - default_display_name: Optional[str] = None, - address: Optional[str] = None, - bind_emails: Optional[Iterable[str]] = None, + threepid: dict | None = None, + user_type: str | None = None, + default_display_name: str | None = None, + address: str | None = None, + bind_emails: Iterable[str] | None = None, by_admin: bool = False, - user_agent_ips: Optional[list[tuple[str, str]]] = None, - auth_provider_id: Optional[str] = None, + user_agent_ips: list[tuple[str, str]] | None = None, + auth_provider_id: str | None = None, approved: bool = False, ) -> str: """Registers a new client on the server. @@ -679,7 +678,7 @@ class RegistrationHandler: return (user_id, service) def check_user_id_not_appservice_exclusive( - self, user_id: str, allowed_appservice: Optional[ApplicationService] = None + self, user_id: str, allowed_appservice: ApplicationService | None = None ) -> None: # don't allow people to register the server notices mxid if self._server_notices_mxid is not None: @@ -704,7 +703,7 @@ class RegistrationHandler: errcode=Codes.EXCLUSIVE, ) - async def check_registration_ratelimit(self, address: Optional[str]) -> None: + async def check_registration_ratelimit(self, address: str | None) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -723,14 +722,14 @@ class RegistrationHandler: async def register_with_store( self, user_id: str, - password_hash: Optional[str] = None, + password_hash: str | None = None, was_guest: bool = False, make_guest: bool = False, - appservice_id: Optional[str] = None, - create_profile_with_displayname: Optional[str] = None, + appservice_id: str | None = None, + create_profile_with_displayname: str | None = None, admin: bool = False, - user_type: Optional[str] = None, - address: Optional[str] = None, + user_type: str | None = None, + address: str | None = None, shadow_banned: bool = False, approved: bool = False, ) -> None: @@ -771,14 +770,14 @@ class RegistrationHandler: async def register_device( self, user_id: str, - device_id: Optional[str], - initial_display_name: Optional[str], + device_id: str | None, + initial_display_name: str | None, is_guest: bool = False, is_appservice_ghost: bool = False, - auth_provider_id: Optional[str] = None, + auth_provider_id: str | None = None, should_issue_refresh_token: bool = False, - auth_provider_session_id: Optional[str] = None, - ) -> tuple[str, str, Optional[int], Optional[str]]: + auth_provider_session_id: str | None = None, + ) -> tuple[str, str, int | None, str | None]: """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. @@ -821,13 +820,13 @@ class RegistrationHandler: async def register_device_inner( self, user_id: str, - device_id: Optional[str], - initial_display_name: Optional[str], + device_id: str | None, + initial_display_name: str | None, is_guest: bool = False, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, + auth_provider_id: str | None = None, + auth_provider_session_id: str | None = None, ) -> LoginDict: """Helper for register_device @@ -927,7 +926,7 @@ class RegistrationHandler: } async def post_registration_actions( - self, user_id: str, auth_result: dict, access_token: Optional[str] + self, user_id: str, auth_result: dict, access_token: str | None ) -> None: """A user has completed registration @@ -977,7 +976,7 @@ class RegistrationHandler: await self.post_consent_actions(user_id) async def _register_email_threepid( - self, user_id: str, threepid: dict, token: Optional[str] + self, user_id: str, threepid: dict, token: str | None ) -> None: """Add an email address as a 3pid identifier diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 217681f7c0..fd38ffa920 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -25,7 +25,6 @@ from typing import ( Collection, Iterable, Mapping, - Optional, Sequence, ) @@ -75,9 +74,9 @@ class BundledAggregations: Some values require additional processing during serialization. """ - references: Optional[JsonDict] = None - replace: Optional[EventBase] = None - thread: Optional[_ThreadAggregation] = None + references: JsonDict | None = None + replace: EventBase | None = None + thread: _ThreadAggregation | None = None def __bool__(self) -> bool: return bool(self.references or self.replace or self.thread) @@ -101,8 +100,8 @@ class RelationsHandler: pagin_config: PaginationConfig, recurse: bool, include_original_event: bool, - relation_type: Optional[str] = None, - event_type: Optional[str] = None, + relation_type: str | None = None, + event_type: str | None = None, ) -> JsonDict: """Get related events of a event, ordered by topological ordering. @@ -553,7 +552,7 @@ class RelationsHandler: room_id: str, include: ThreadsListInclude, limit: int = 5, - from_token: Optional[ThreadsNextBatch] = None, + from_token: ThreadsNextBatch | None = None, ) -> JsonDict: """Get related events of a event, ordered by topological ordering. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f242accef1..d62ad5393f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -33,7 +33,6 @@ from typing import ( Any, Awaitable, Callable, - Optional, cast, ) @@ -198,7 +197,7 @@ class RoomCreationHandler: requester: Requester, old_room_id: str, new_version: RoomVersion, - additional_creators: Optional[list[str]], + additional_creators: list[str] | None, auto_member: bool = False, ratelimit: bool = True, ) -> str: @@ -341,10 +340,11 @@ class RoomCreationHandler: new_version: RoomVersion, tombstone_event: EventBase, tombstone_context: synapse.events.snapshot.EventContext, - additional_creators: Optional[list[str]], - creation_event_with_context: Optional[ - tuple[EventBase, synapse.events.snapshot.EventContext] - ] = None, + additional_creators: list[str] | None, + creation_event_with_context: tuple[ + EventBase, synapse.events.snapshot.EventContext + ] + | None = None, auto_member: bool = False, ) -> str: """ @@ -434,7 +434,7 @@ class RoomCreationHandler: old_room_id: str, new_room_id: str, old_room_state: StateMap[str], - additional_creators: Optional[list[str]], + additional_creators: list[str] | None, ) -> None: """Send updated power levels in both rooms after an upgrade @@ -524,9 +524,9 @@ class RoomCreationHandler: def _calculate_upgraded_room_creation_content( self, old_room_create_event: EventBase, - tombstone_event_id: Optional[str], + tombstone_event_id: str | None, new_room_version: RoomVersion, - additional_creators: Optional[list[str]], + additional_creators: list[str] | None, ) -> JsonDict: creation_content: JsonDict = { "room_version": new_room_version.identifier, @@ -558,10 +558,11 @@ class RoomCreationHandler: new_room_id: str, new_room_version: RoomVersion, tombstone_event_id: str, - additional_creators: Optional[list[str]], - creation_event_with_context: Optional[ - tuple[EventBase, synapse.events.snapshot.EventContext] - ] = None, + additional_creators: list[str] | None, + creation_event_with_context: tuple[ + EventBase, synapse.events.snapshot.EventContext + ] + | None = None, auto_member: bool = False, ) -> None: """Populate a new room based on an old room @@ -597,7 +598,7 @@ class RoomCreationHandler: initial_state: MutableStateMap = {} # Replicate relevant room events - types_to_copy: list[tuple[str, Optional[str]]] = [ + types_to_copy: list[tuple[str, str | None]] = [ (EventTypes.JoinRules, ""), (EventTypes.Name, ""), (EventTypes.Topic, ""), @@ -1039,9 +1040,9 @@ class RoomCreationHandler: requester: Requester, config: JsonDict, ratelimit: bool = True, - creator_join_profile: Optional[JsonDict] = None, + creator_join_profile: JsonDict | None = None, ignore_forced_encryption: bool = False, - ) -> tuple[str, Optional[RoomAlias], int]: + ) -> tuple[str, RoomAlias | None, int]: """Creates a new room. Args: @@ -1426,13 +1427,14 @@ class RoomCreationHandler: invite_list: list[str], initial_state: MutableStateMap, creation_content: JsonDict, - room_alias: Optional[RoomAlias] = None, - power_level_content_override: Optional[JsonDict] = None, - creator_join_profile: Optional[JsonDict] = None, + room_alias: RoomAlias | None = None, + power_level_content_override: JsonDict | None = None, + creator_join_profile: JsonDict | None = None, ignore_forced_encryption: bool = False, - creation_event_with_context: Optional[ - tuple[EventBase, synapse.events.snapshot.EventContext] - ] = None, + creation_event_with_context: tuple[ + EventBase, synapse.events.snapshot.EventContext + ] + | None = None, ) -> tuple[int, str, int]: """Sends the initial events into a new room. Sends the room creation, membership, and power level events into the room sequentially, then creates and batches up the @@ -1813,7 +1815,7 @@ class RoomCreationHandler: self, users_map: dict[str, int], creator: str, - additional_creators: Optional[list[str]], + additional_creators: list[str] | None, ) -> None: creators = [creator] if additional_creators: @@ -1880,9 +1882,9 @@ class RoomContextHandler: room_id: str, event_id: str, limit: int, - event_filter: Optional[Filter], + event_filter: Filter | None, use_admin_priviledge: bool = False, - ) -> Optional[EventContext]: + ) -> EventContext | None: """Retrieves events, pagination tokens and state around a given event in a room. @@ -2168,7 +2170,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): limit: int, room_ids: StrCollection, is_guest: bool, - explicit_room_id: Optional[str] = None, + explicit_room_id: str | None = None, ) -> tuple[list[EventBase], RoomStreamToken]: # We just ignore the key for now. @@ -2244,11 +2246,10 @@ class RoomShutdownHandler: self, room_id: str, params: ShutdownRoomParams, - result: Optional[ShutdownRoomResponse] = None, - update_result_fct: Optional[ - Callable[[Optional[JsonMapping]], Awaitable[None]] - ] = None, - ) -> Optional[ShutdownRoomResponse]: + result: ShutdownRoomResponse | None = None, + update_result_fct: Callable[[JsonMapping | None], Awaitable[None]] + | None = None, + ) -> ShutdownRoomResponse | None: """ Shuts down a room. Moves all joined local users and room aliases automatically to a new room if `new_room_user_id` is set. Otherwise local users only diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 97a5d07c7c..6377931b39 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import attr import msgpack @@ -67,14 +67,14 @@ class RoomListHandler: self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ - tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] + tuple[int | None, str | None, ThirdPartyInstanceID | None] ] = ResponseCache( clock=hs.get_clock(), name="room_list", server_name=self.server_name, ) self.remote_response_cache: ResponseCache[ - tuple[str, Optional[int], Optional[str], bool, Optional[str]] + tuple[str, int | None, str | None, bool, str | None] ] = ResponseCache( clock=hs.get_clock(), name="remote_room_list", @@ -84,11 +84,11 @@ class RoomListHandler: async def get_local_public_room_list( self, - limit: Optional[int] = None, - since_token: Optional[str] = None, - search_filter: Optional[dict] = None, - network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, - from_federation_origin: Optional[str] = None, + limit: int | None = None, + since_token: str | None = None, + search_filter: dict | None = None, + network_tuple: ThirdPartyInstanceID | None = EMPTY_THIRD_PARTY_ID, + from_federation_origin: str | None = None, ) -> JsonDict: """Generate a local public room list. @@ -150,10 +150,10 @@ class RoomListHandler: async def _get_public_room_list( self, limit: int, - since_token: Optional[str] = None, - search_filter: Optional[dict] = None, - network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, - from_federation_origin: Optional[str] = None, + since_token: str | None = None, + search_filter: dict | None = None, + network_tuple: ThirdPartyInstanceID | None = EMPTY_THIRD_PARTY_ID, + from_federation_origin: str | None = None, ) -> JsonDict: """Generate a public room list. Args: @@ -175,7 +175,7 @@ class RoomListHandler: if since_token: batch_token = RoomListNextBatch.from_token(since_token) - bounds: Optional[tuple[int, str]] = ( + bounds: tuple[int, str] | None = ( batch_token.last_joined_members, batch_token.last_room_id, ) @@ -235,8 +235,8 @@ class RoomListHandler: # `len(room_entries) >= limit` and we might be left with rooms we didn't # 'consider' (iterate over) and we should save those rooms for the next # batch. - first_considered_room: Optional[LargestRoomStats] = None - last_considered_room: Optional[LargestRoomStats] = None + first_considered_room: LargestRoomStats | None = None + last_considered_room: LargestRoomStats | None = None cut_off_due_to_limit: bool = False for room_result in rooms_iterator: @@ -349,7 +349,7 @@ class RoomListHandler: cache_context: _CacheContext, with_alias: bool = True, allow_private: bool = False, - ) -> Optional[JsonMapping]: + ) -> JsonMapping | None: """Returns the entry for a room Args: @@ -455,11 +455,11 @@ class RoomListHandler: async def get_remote_public_room_list( self, server_name: str, - limit: Optional[int] = None, - since_token: Optional[str] = None, - search_filter: Optional[dict] = None, + limit: int | None = None, + since_token: str | None = None, + search_filter: dict | None = None, include_all_networks: bool = False, - third_party_instance_id: Optional[str] = None, + third_party_instance_id: str | None = None, ) -> JsonDict: """Get the public room list from remote server @@ -531,11 +531,11 @@ class RoomListHandler: async def _get_remote_list_cached( self, server_name: str, - limit: Optional[int] = None, - since_token: Optional[str] = None, - search_filter: Optional[dict] = None, + limit: int | None = None, + since_token: str | None = None, + search_filter: dict | None = None, include_all_networks: bool = False, - third_party_instance_id: Optional[str] = None, + third_party_instance_id: str | None = None, ) -> JsonDict: """Wrapper around FederationClient.get_public_rooms that caches the result. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 03cfc99260..d5f72c1732 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -23,7 +23,7 @@ import abc import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable from synapse import types from synapse.api.constants import ( @@ -260,7 +260,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def remote_reject_invite( self, invite_event_id: str, - txn_id: Optional[str], + txn_id: str | None, requester: Requester, content: JsonDict, ) -> tuple[str, int]: @@ -283,7 +283,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def remote_rescind_knock( self, knock_event_id: str, - txn_id: Optional[str], + txn_id: str | None, requester: Requester, content: JsonDict, ) -> tuple[str, int]: @@ -349,8 +349,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def ratelimit_multiple_invites( self, - requester: Optional[Requester], - room_id: Optional[str], + requester: Requester | None, + room_id: str | None, n_invites: int, update: bool = True, ) -> None: @@ -374,8 +374,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def ratelimit_invite( self, - requester: Optional[Requester], - room_id: Optional[str], + requester: Requester | None, + room_id: str | None, invitee_user_id: str, ) -> None: """Ratelimit invites by room and by target user. @@ -396,15 +396,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): target: UserID, room_id: str, membership: str, - prev_event_ids: Optional[list[str]] = None, - state_event_ids: Optional[list[str]] = None, - depth: Optional[int] = None, - txn_id: Optional[str] = None, + prev_event_ids: list[str] | None = None, + state_event_ids: list[str] | None = None, + depth: int | None = None, + txn_id: str | None = None, ratelimit: bool = True, - content: Optional[dict] = None, + content: dict | None = None, require_consent: bool = True, outlier: bool = False, - origin_server_ts: Optional[int] = None, + origin_server_ts: int | None = None, ) -> tuple[str, int]: """ Internal membership update function to get an existing event or create @@ -572,18 +572,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): target: UserID, room_id: str, action: str, - txn_id: Optional[str] = None, - remote_room_hosts: Optional[list[str]] = None, - third_party_signed: Optional[dict] = None, + txn_id: str | None = None, + remote_room_hosts: list[str] | None = None, + third_party_signed: dict | None = None, ratelimit: bool = True, - content: Optional[dict] = None, + content: dict | None = None, new_room: bool = False, require_consent: bool = True, outlier: bool = False, - prev_event_ids: Optional[list[str]] = None, - state_event_ids: Optional[list[str]] = None, - depth: Optional[int] = None, - origin_server_ts: Optional[int] = None, + prev_event_ids: list[str] | None = None, + state_event_ids: list[str] | None = None, + depth: int | None = None, + origin_server_ts: int | None = None, ) -> tuple[str, int]: """Update a user's membership in a room. @@ -686,18 +686,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): target: UserID, room_id: str, action: str, - txn_id: Optional[str] = None, - remote_room_hosts: Optional[list[str]] = None, - third_party_signed: Optional[dict] = None, + txn_id: str | None = None, + remote_room_hosts: list[str] | None = None, + third_party_signed: dict | None = None, ratelimit: bool = True, - content: Optional[dict] = None, + content: dict | None = None, new_room: bool = False, require_consent: bool = True, outlier: bool = False, - prev_event_ids: Optional[list[str]] = None, - state_event_ids: Optional[list[str]] = None, - depth: Optional[int] = None, - origin_server_ts: Optional[int] = None, + prev_event_ids: list[str] | None = None, + state_event_ids: list[str] | None = None, + depth: int | None = None, + origin_server_ts: int | None = None, ) -> tuple[str, int]: """Helper for update_membership. @@ -1420,7 +1420,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def send_membership_event( self, - requester: Optional[Requester], + requester: Requester | None, event: EventBase, context: EventContext, ratelimit: bool = True, @@ -1594,7 +1594,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): return RoomID.from_string(room_id), servers - async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]: + async def _get_inviter(self, user_id: str, room_id: str) -> UserID | None: invite = await self.store.get_invite_for_local_user_in_room( user_id=user_id, room_id=room_id ) @@ -1610,10 +1610,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): address: str, id_server: str, requester: Requester, - txn_id: Optional[str], + txn_id: str | None, id_access_token: str, - prev_event_ids: Optional[list[str]] = None, - depth: Optional[int] = None, + prev_event_ids: list[str] | None = None, + depth: int | None = None, ) -> tuple[str, int]: """Invite a 3PID to a room. @@ -1724,10 +1724,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): address: str, room_id: str, user: UserID, - txn_id: Optional[str], + txn_id: str | None, id_access_token: str, - prev_event_ids: Optional[list[str]] = None, - depth: Optional[int] = None, + prev_event_ids: list[str] | None = None, + depth: int | None = None, ) -> tuple[EventBase, int]: room_state = await self._storage_controllers.state.get_current_state( room_id, @@ -1864,7 +1864,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): async def _is_remote_room_too_complex( self, room_id: str, remote_room_hosts: list[str] - ) -> Optional[bool]: + ) -> bool | None: """ Check if complexity of a remote room is too great. @@ -1977,7 +1977,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): async def remote_reject_invite( self, invite_event_id: str, - txn_id: Optional[str], + txn_id: str | None, requester: Requester, content: JsonDict, ) -> tuple[str, int]: @@ -2014,7 +2014,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): async def remote_rescind_knock( self, knock_event_id: str, - txn_id: Optional[str], + txn_id: str | None, requester: Requester, content: JsonDict, ) -> tuple[str, int]: @@ -2043,7 +2043,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): async def _generate_local_out_of_band_leave( self, previous_membership_event: EventBase, - txn_id: Optional[str], + txn_id: str | None, requester: Requester, content: JsonDict, ) -> tuple[str, int]: @@ -2180,7 +2180,7 @@ class RoomForgetterHandler(StateDeltasHandler): self._room_member_handler = hs.get_room_member_handler() # The current position in the current_state_delta stream - self.pos: Optional[int] = None + self.pos: int | None = None # Guard to ensure we only process deltas one at a time self._is_processing = False diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 0927c031f7..b56519ab0a 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler from synapse.replication.http.membership import ( @@ -73,7 +73,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): async def remote_reject_invite( self, invite_event_id: str, - txn_id: Optional[str], + txn_id: str | None, requester: Requester, content: dict, ) -> tuple[str, int]: @@ -93,7 +93,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): async def remote_rescind_knock( self, knock_event_id: str, - txn_id: Optional[str], + txn_id: str | None, requester: Requester, content: JsonDict, ) -> tuple[str, int]: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index a3247d3cda..9ec0d33f11 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -71,7 +71,7 @@ class _PaginationKey: # during a pagination session). room_id: str suggested_only: bool - max_depth: Optional[int] + max_depth: int | None # The randomly generated token. token: str @@ -118,10 +118,10 @@ class RoomSummaryHandler: bool, bool, bool, - Optional[int], - Optional[int], - Optional[str], - Optional[tuple[str, ...]], + int | None, + int | None, + str | None, + tuple[str, ...] | None, ] ] = ResponseCache( clock=hs.get_clock(), @@ -137,10 +137,10 @@ class RoomSummaryHandler: suggested_only: bool = False, omit_remote_room_hierarchy: bool = False, admin_skip_room_visibility_check: bool = False, - max_depth: Optional[int] = None, - limit: Optional[int] = None, - from_token: Optional[str] = None, - remote_room_hosts: Optional[tuple[str, ...]] = None, + max_depth: int | None = None, + limit: int | None = None, + from_token: str | None = None, + remote_room_hosts: tuple[str, ...] | None = None, ) -> JsonDict: """ Implementation of the room hierarchy C-S API. @@ -208,10 +208,10 @@ class RoomSummaryHandler: suggested_only: bool = False, omit_remote_room_hierarchy: bool = False, admin_skip_room_visibility_check: bool = False, - max_depth: Optional[int] = None, - limit: Optional[int] = None, - from_token: Optional[str] = None, - remote_room_hosts: Optional[tuple[str, ...]] = None, + max_depth: int | None = None, + limit: int | None = None, + from_token: str | None = None, + remote_room_hosts: tuple[str, ...] | None = None, ) -> JsonDict: """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" @@ -480,8 +480,8 @@ class RoomSummaryHandler: async def _summarize_local_room( self, - requester: Optional[str], - origin: Optional[str], + requester: str | None, + origin: str | None, room_id: str, suggested_only: bool, include_children: bool = True, @@ -594,7 +594,7 @@ class RoomSummaryHandler: ) async def _is_local_room_accessible( - self, room_id: str, requester: Optional[str], origin: Optional[str] = None + self, room_id: str, requester: str | None, origin: str | None = None ) -> bool: """ Calculate whether the room should be shown to the requester. @@ -723,7 +723,7 @@ class RoomSummaryHandler: return False async def _is_remote_room_accessible( - self, requester: Optional[str], room_id: str, room: JsonDict + self, requester: str | None, room_id: str, room: JsonDict ) -> bool: """ Calculate whether the room received over federation should be shown to the requester. @@ -864,9 +864,9 @@ class RoomSummaryHandler: async def get_room_summary( self, - requester: Optional[str], + requester: str | None, room_id: str, - remote_room_hosts: Optional[list[str]] = None, + remote_room_hosts: list[str] | None = None, ) -> JsonDict: """ Implementation of the room summary C-S API from MSC3266 @@ -965,7 +965,7 @@ class _RoomQueueEntry: depth: int = 0 # The room summary for this room returned via federation. This will only be # used if the room is not known locally (and is not a space). - remote_room: Optional[JsonDict] = None + remote_room: JsonDict | None = None @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -1026,7 +1026,7 @@ _INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]") def _child_events_comparison_key( child: EventBase, -) -> tuple[bool, Optional[str], int, str]: +) -> tuple[bool, str | None, int, str]: """ Generate a value for comparing two child events for ordering. diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 218fbcaaa7..8f2b37c46d 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -20,7 +20,7 @@ # import logging import re -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable import attr import saml2 @@ -54,7 +54,7 @@ class Saml2SessionData: creation_time: int # The user interactive authentication session ID associated with this SAML # session (or None if this SAML session is for an initial login). - ui_auth_session_id: Optional[str] = None + ui_auth_session_id: str | None = None class SamlHandler: @@ -98,8 +98,8 @@ class SamlHandler: async def handle_redirect_request( self, request: SynapseRequest, - client_redirect_url: Optional[bytes], - ui_auth_session_id: Optional[str] = None, + client_redirect_url: bytes | None, + ui_auth_session_id: str | None = None, ) -> str: """Handle an incoming request to /login/sso/redirect @@ -303,7 +303,7 @@ class SamlHandler: emails=result.get("emails", []), ) - async def grandfather_existing_users() -> Optional[str]: + async def grandfather_existing_users() -> str | None: # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid if ( @@ -341,7 +341,7 @@ class SamlHandler: def _remote_id_from_saml_response( self, saml2_auth: saml2.response.AuthnResponse, - client_redirect_url: Optional[str], + client_redirect_url: str | None, ) -> str: """Extract the unique remote id from a SAML2 AuthnResponse diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 8f39c6ec6b..20b38427a6 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -21,7 +21,7 @@ import itertools import logging -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable import attr from unpaddedbase64 import decode_base64, encode_base64 @@ -117,7 +117,7 @@ class SearchHandler: return historical_room_ids async def search( - self, requester: Requester, content: JsonDict, batch: Optional[str] = None + self, requester: Requester, content: JsonDict, batch: str | None = None ) -> JsonDict: """Performs a full text search for a user. @@ -226,18 +226,18 @@ class SearchHandler: async def _search( self, requester: Requester, - batch_group: Optional[str], - batch_group_key: Optional[str], - batch_token: Optional[str], + batch_group: str | None, + batch_group_key: str | None, + batch_token: str | None, search_term: str, keys: list[str], filter_dict: JsonDict, order_by: str, include_state: bool, group_keys: list[str], - event_context: Optional[bool], - before_limit: Optional[int], - after_limit: Optional[int], + event_context: bool | None, + before_limit: int | None, + after_limit: int | None, include_profile: bool, ) -> JsonDict: """Performs a full text search for a user. @@ -307,7 +307,7 @@ class SearchHandler: } } - sender_group: Optional[dict[str, JsonDict]] + sender_group: dict[str, JsonDict] | None if order_by == "rank": search_result, sender_group = await self._search_by_rank( @@ -517,10 +517,10 @@ class SearchHandler: search_term: str, keys: Iterable[str], search_filter: Filter, - batch_group: Optional[str], - batch_group_key: Optional[str], - batch_token: Optional[str], - ) -> tuple[_SearchResult, Optional[str]]: + batch_group: str | None, + batch_group_key: str | None, + batch_token: str | None, + ) -> tuple[_SearchResult, str | None]: """ Performs a full text search for a user ordering by recent. diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index 02fd48dbad..8cdf9c6a87 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -24,7 +24,7 @@ import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from io import BytesIO -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.internet.defer import Deferred from twisted.internet.endpoints import HostnameEndpoint @@ -49,13 +49,13 @@ async def _sendmail( from_addr: str, to_addr: str, msg_bytes: bytes, - username: Optional[bytes] = None, - password: Optional[bytes] = None, + username: bytes | None = None, + password: bytes | None = None, require_auth: bool = False, require_tls: bool = False, enable_tls: bool = True, force_tls: bool = False, - tlsname: Optional[str] = None, + tlsname: str | None = None, ) -> None: """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests @@ -136,7 +136,7 @@ class SendEmailHandler: app_name: str, html: str, text: str, - additional_headers: Optional[dict[str, str]] = None, + additional_headers: dict[str, str] | None = None, ) -> None: """Send a multipart email with the given information. diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 54116a9b72..042cb4e1b5 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -18,7 +18,7 @@ # # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester @@ -42,7 +42,7 @@ class SetPasswordHandler: user_id: str, password_hash: str, logout_devices: bool, - requester: Optional[Requester] = None, + requester: Requester | None = None, ) -> None: if not self._auth_handler.can_change_password(): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index cea4b857ee..6a5d5c7b3c 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -15,7 +15,7 @@ import itertools import logging from itertools import chain -from typing import TYPE_CHECKING, AbstractSet, Mapping, Optional +from typing import TYPE_CHECKING, AbstractSet, Mapping from prometheus_client import Histogram from typing_extensions import assert_never @@ -114,7 +114,7 @@ class SlidingSyncHandler: self, requester: Requester, sync_config: SlidingSyncConfig, - from_token: Optional[SlidingSyncStreamToken] = None, + from_token: SlidingSyncStreamToken | None = None, timeout_ms: int = 0, ) -> tuple[SlidingSyncResult, bool]: """ @@ -201,7 +201,7 @@ class SlidingSyncHandler: self, sync_config: SlidingSyncConfig, to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken] = None, + from_token: SlidingSyncStreamToken | None = None, ) -> SlidingSyncResult: """ Generates the response body of a Sliding Sync result, represented as a @@ -550,7 +550,7 @@ class SlidingSyncHandler: room_id: str, room_sync_config: RoomSyncConfig, room_membership_for_user_at_to_token: RoomsForUserType, - from_token: Optional[SlidingSyncStreamToken], + from_token: SlidingSyncStreamToken | None, to_token: StreamToken, newly_joined: bool, newly_left: bool, @@ -678,10 +678,10 @@ class SlidingSyncHandler: # `invite`/`knock` rooms only have `stripped_state`. See # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 timeline_events: list[EventBase] = [] - bundled_aggregations: Optional[dict[str, BundledAggregations]] = None - limited: Optional[bool] = None - prev_batch_token: Optional[StreamToken] = None - num_live: Optional[int] = None + bundled_aggregations: dict[str, BundledAggregations] | None = None + limited: bool | None = None + prev_batch_token: StreamToken | None = None + num_live: int | None = None if ( room_sync_config.timeline_limit > 0 # No timeline for invite/knock rooms (just `stripped_state`) @@ -850,7 +850,7 @@ class SlidingSyncHandler: # For incremental syncs, we can do this first to determine if something relevant # has changed and strategically avoid fetching other costly things. room_state_delta_id_map: MutableStateMap[str] = {} - name_event_id: Optional[str] = None + name_event_id: str | None = None membership_changed = False name_changed = False avatar_changed = False @@ -914,7 +914,7 @@ class SlidingSyncHandler: # We only need the room summary for calculating heroes, however if we do # fetch it then we can use it to calculate `joined_count` and # `invited_count`. - room_membership_summary: Optional[Mapping[str, MemberSummary]] = None + room_membership_summary: Mapping[str, MemberSummary] | None = None # `heroes` are required if the room name is not set. # @@ -950,8 +950,8 @@ class SlidingSyncHandler: # # Similarly to other metadata, we only need to calculate the member # counts if this is an initial sync or the memberships have changed. - joined_count: Optional[int] = None - invited_count: Optional[int] = None + joined_count: int | None = None + invited_count: int | None = None if ( initial or membership_changed ) and room_membership_for_user_at_to_token.membership == Membership.JOIN: @@ -1036,7 +1036,7 @@ class SlidingSyncHandler: ) required_state_filter = StateFilter.all() else: - required_state_types: list[tuple[str, Optional[str]]] = [] + required_state_types: list[tuple[str, str | None]] = [] num_wild_state_keys = 0 lazy_load_room_members = False num_others = 0 @@ -1146,7 +1146,7 @@ class SlidingSyncHandler: # The required state map to store in the room sync config, if it has # changed. - changed_required_state_map: Optional[Mapping[str, AbstractSet[str]]] = None + changed_required_state_map: Mapping[str, AbstractSet[str]] | None = None # We can return all of the state that was requested if this was the first # time we've sent the room down this connection. @@ -1205,7 +1205,7 @@ class SlidingSyncHandler: required_room_state = required_state_filter.filter_state(room_state) # Find the room name and avatar from the state - room_name: Optional[str] = None + room_name: str | None = None # TODO: Should we also check for `EventTypes.CanonicalAlias` # (`m.room.canonical_alias`) as a fallback for the room name? see # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153 @@ -1213,7 +1213,7 @@ class SlidingSyncHandler: if name_event is not None: room_name = name_event.content.get("name") - room_avatar: Optional[str] = None + room_avatar: str | None = None avatar_event = room_state.get((EventTypes.RoomAvatar, "")) if avatar_event is not None: room_avatar = avatar_event.content.get("url") @@ -1376,7 +1376,7 @@ class SlidingSyncHandler: to_token: StreamToken, timeline: list[EventBase], check_outside_timeline: bool, - ) -> Optional[int]: + ) -> int | None: """Get a bump stamp for the room, if we have a bump event and it has changed. @@ -1479,7 +1479,7 @@ def _required_state_changes( prev_required_state_map: Mapping[str, AbstractSet[str]], request_required_state_map: Mapping[str, AbstractSet[str]], state_deltas: StateMap[str], -) -> tuple[Optional[Mapping[str, AbstractSet[str]]], StateFilter]: +) -> tuple[Mapping[str, AbstractSet[str]] | None, StateFilter]: """Calculates the changes between the required state room config from the previous requests compared with the current request. @@ -1528,7 +1528,7 @@ def _required_state_changes( # The set of types/state keys that we need to fetch and return to the # client. Passed to `StateFilter.from_types(...)` - added: list[tuple[str, Optional[str]]] = [] + added: list[tuple[str, str | None]] = [] # Convert the list of state deltas to map from type to state_keys that have # changed. diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 221af86f7d..d076bec51a 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -20,7 +20,6 @@ from typing import ( ChainMap, Mapping, MutableMapping, - Optional, Sequence, cast, ) @@ -86,7 +85,7 @@ class SlidingSyncExtensionHandler: actual_room_ids: set[str], actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult], to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], + from_token: SlidingSyncStreamToken | None, ) -> SlidingSyncResult.Extensions: """Handle extension requests. @@ -202,8 +201,8 @@ class SlidingSyncExtensionHandler: def find_relevant_room_ids_for_extension( self, - requested_lists: Optional[StrCollection], - requested_room_ids: Optional[StrCollection], + requested_lists: StrCollection | None, + requested_room_ids: StrCollection | None, actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList], actual_room_ids: AbstractSet[str], ) -> set[str]: @@ -246,7 +245,7 @@ class SlidingSyncExtensionHandler: if requested_lists is not None: for list_key in requested_lists: # Just some typing because we share the variable name in multiple places - actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None + actual_list: SlidingSyncResult.SlidingWindowList | None = None # A wildcard means we process rooms from all lists if list_key == "*": @@ -277,7 +276,7 @@ class SlidingSyncExtensionHandler: sync_config: SlidingSyncConfig, to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension, to_token: StreamToken, - ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]: + ) -> SlidingSyncResult.Extensions.ToDeviceExtension | None: """Handle to-device extension (MSC3885) Args: @@ -352,8 +351,8 @@ class SlidingSyncExtensionHandler: sync_config: SlidingSyncConfig, e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension, to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]: + from_token: SlidingSyncStreamToken | None, + ) -> SlidingSyncResult.Extensions.E2eeExtension | None: """Handle E2EE device extension (MSC3884) Args: @@ -369,7 +368,7 @@ class SlidingSyncExtensionHandler: if not e2ee_request.enabled: return None - device_list_updates: Optional[DeviceListUpdates] = None + device_list_updates: DeviceListUpdates | None = None if from_token is not None: # TODO: This should take into account the `from_token` and `to_token` device_list_updates = await self.device_handler.get_user_ids_changed( @@ -407,8 +406,8 @@ class SlidingSyncExtensionHandler: actual_room_ids: set[str], account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension, to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> Optional[SlidingSyncResult.Extensions.AccountDataExtension]: + from_token: SlidingSyncStreamToken | None, + ) -> SlidingSyncResult.Extensions.AccountDataExtension | None: """Handle Account Data extension (MSC3959) Args: @@ -640,8 +639,8 @@ class SlidingSyncExtensionHandler: actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult], receipts_request: SlidingSyncConfig.Extensions.ReceiptsExtension, to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> Optional[SlidingSyncResult.Extensions.ReceiptsExtension]: + from_token: SlidingSyncStreamToken | None, + ) -> SlidingSyncResult.Extensions.ReceiptsExtension | None: """Handle Receipts extension (MSC3960) Args: @@ -844,8 +843,8 @@ class SlidingSyncExtensionHandler: actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult], typing_request: SlidingSyncConfig.Extensions.TypingExtension, to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> Optional[SlidingSyncResult.Extensions.TypingExtension]: + from_token: SlidingSyncStreamToken | None, + ) -> SlidingSyncResult.Extensions.TypingExtension | None: """Handle Typing Notification extension (MSC3961) Args: @@ -905,8 +904,8 @@ class SlidingSyncExtensionHandler: sync_config: SlidingSyncConfig, thread_subscriptions_request: SlidingSyncConfig.Extensions.ThreadSubscriptionsExtension, to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> Optional[SlidingSyncResult.Extensions.ThreadSubscriptionsExtension]: + from_token: SlidingSyncStreamToken | None, + ) -> SlidingSyncResult.Extensions.ThreadSubscriptionsExtension | None: """Handle Thread Subscriptions extension (MSC4308) Args: diff --git a/synapse/handlers/sliding_sync/room_lists.py b/synapse/handlers/sliding_sync/room_lists.py index fc77fd3c65..3d11902236 100644 --- a/synapse/handlers/sliding_sync/room_lists.py +++ b/synapse/handlers/sliding_sync/room_lists.py @@ -21,8 +21,6 @@ from typing import ( Literal, Mapping, MutableMapping, - Optional, - Union, cast, ) @@ -81,7 +79,7 @@ logger = logging.getLogger(__name__) # Helper definition for the types that we might return. We do this to avoid # copying data between types (which can be expensive for many rooms). -RoomsForUserType = Union[RoomsForUserStateReset, RoomsForUser, RoomsForUserSlidingSync] +RoomsForUserType = RoomsForUserStateReset | RoomsForUser | RoomsForUserSlidingSync @attr.s(auto_attribs=True, slots=True, frozen=True) @@ -184,7 +182,7 @@ class SlidingSyncRoomLists: sync_config: SlidingSyncConfig, previous_connection_state: "PerConnectionState", to_token: StreamToken, - from_token: Optional[StreamToken], + from_token: StreamToken | None, ) -> SlidingSyncInterestedRooms: """Fetch the set of rooms that match the request""" has_lists = sync_config.lists is not None and len(sync_config.lists) > 0 @@ -221,7 +219,7 @@ class SlidingSyncRoomLists: sync_config: SlidingSyncConfig, previous_connection_state: "PerConnectionState", to_token: StreamToken, - from_token: Optional[StreamToken], + from_token: StreamToken | None, ) -> SlidingSyncInterestedRooms: """Implementation of `compute_interested_rooms` using new sliding sync db tables.""" user_id = sync_config.user.to_string() @@ -620,7 +618,7 @@ class SlidingSyncRoomLists: sync_config: SlidingSyncConfig, previous_connection_state: "PerConnectionState", to_token: StreamToken, - from_token: Optional[StreamToken], + from_token: StreamToken | None, ) -> SlidingSyncInterestedRooms: """Fallback code when the database background updates haven't completed yet.""" @@ -806,7 +804,7 @@ class SlidingSyncRoomLists: async def _filter_relevant_rooms_to_send( self, previous_connection_state: PerConnectionState, - from_token: Optional[StreamToken], + from_token: StreamToken | None, relevant_room_map: dict[str, RoomSyncConfig], ) -> dict[str, RoomSyncConfig]: """Filters the `relevant_room_map` down to those rooms that may have @@ -879,7 +877,7 @@ class SlidingSyncRoomLists: user: UserID, rooms_for_user: Mapping[str, RoomsForUserType], to_token: StreamToken, - ) -> Mapping[str, Optional[RoomsForUser]]: + ) -> Mapping[str, RoomsForUser | None]: """ Takes the current set of rooms for a user (retrieved after the given token), and returns the changes needed to "rewind" it to match the set of @@ -962,7 +960,7 @@ class SlidingSyncRoomLists: # Otherwise we're about to make changes to `rooms_for_user`, so we turn # it into a mutable dict. - changes: dict[str, Optional[RoomsForUser]] = {} + changes: dict[str, RoomsForUser | None] = {} # Assemble a list of the first membership event after the `to_token` so we can # step backward to the previous membership that would apply to the from/to @@ -1028,7 +1026,7 @@ class SlidingSyncRoomLists: self, user: UserID, to_token: StreamToken, - from_token: Optional[StreamToken], + from_token: StreamToken | None, ) -> tuple[dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]: """ Fetch room IDs that the user has had membership in (the full room list including @@ -1138,7 +1136,7 @@ class SlidingSyncRoomLists: self, user_id: str, to_token: StreamToken, - from_token: Optional[StreamToken], + from_token: StreamToken | None, ) -> tuple[AbstractSet[str], Mapping[str, RoomsForUserStateReset]]: """Fetch the sets of rooms that the user newly joined or left in the given token range. @@ -1185,7 +1183,7 @@ class SlidingSyncRoomLists: self, user_id: str, to_token: StreamToken, - from_token: Optional[StreamToken], + from_token: StreamToken | None, ) -> tuple[AbstractSet[str], Mapping[str, RoomsForUserStateReset]]: """Fetch the sets of rooms that the user newly joined or left in the given token range. @@ -1400,7 +1398,7 @@ class SlidingSyncRoomLists: room_id: str, room_membership_for_user_map: dict[str, RoomsForUserType], to_token: StreamToken, - ) -> Optional[RoomsForUserType]: + ) -> RoomsForUserType | None: """ Check whether the user is allowed to see the room based on whether they have ever had membership in the room or if the room is `world_readable`. @@ -1466,7 +1464,7 @@ class SlidingSyncRoomLists: self, room_ids: StrCollection, sync_room_map: dict[str, RoomsForUserType], - ) -> dict[str, Optional[StateMap[StrippedStateEvent]]]: + ) -> dict[str, StateMap[StrippedStateEvent] | None]: """ Fetch stripped state for a list of room IDs. Stripped state is only applicable to invite/knock rooms. Other rooms will have `None` as their @@ -1485,7 +1483,7 @@ class SlidingSyncRoomLists: event. """ room_id_to_stripped_state_map: dict[ - str, Optional[StateMap[StrippedStateEvent]] + str, StateMap[StrippedStateEvent] | None ] = {} # Fetch what we haven't before @@ -1530,7 +1528,7 @@ class SlidingSyncRoomLists: f"Unexpected membership {membership} (this is a problem with Synapse itself)" ) - stripped_state_map: Optional[MutableStateMap[StrippedStateEvent]] = None + stripped_state_map: MutableStateMap[StrippedStateEvent] | None = None # Scrutinize unsigned things. `raw_stripped_state_events` should be a list # of stripped events if raw_stripped_state_events is not None: @@ -1564,10 +1562,8 @@ class SlidingSyncRoomLists: room_ids: set[str], sync_room_map: dict[str, RoomsForUserType], to_token: StreamToken, - room_id_to_stripped_state_map: dict[ - str, Optional[StateMap[StrippedStateEvent]] - ], - ) -> Mapping[str, Union[Optional[str], StateSentinel]]: + room_id_to_stripped_state_map: dict[str, StateMap[StrippedStateEvent] | None], + ) -> Mapping[str, str | None | StateSentinel]: """ Get the given state event content for a list of rooms. First we check the current state of the room, then fallback to stripped state if available, then @@ -1589,7 +1585,7 @@ class SlidingSyncRoomLists: the given state event (event_type, ""), otherwise `None`. Rooms unknown to this server will return `ROOM_UNKNOWN_SENTINEL`. """ - room_id_to_content: dict[str, Union[Optional[str], StateSentinel]] = {} + room_id_to_content: dict[str, str | None | StateSentinel] = {} # As a bulk shortcut, use the current state if the server is particpating in the # room (meaning we have current state). Ideally, for leave/ban rooms, we would @@ -1750,7 +1746,7 @@ class SlidingSyncRoomLists: user_id = user.to_string() room_id_to_stripped_state_map: dict[ - str, Optional[StateMap[StrippedStateEvent]] + str, StateMap[StrippedStateEvent] | None ] = {} filtered_room_id_set = set(sync_room_map.keys()) @@ -2107,7 +2103,7 @@ class SlidingSyncRoomLists: self, sync_room_map: dict[str, RoomsForUserType], to_token: StreamToken, - limit: Optional[int] = None, + limit: int | None = None, ) -> list[RoomsForUserType]: """ Sort by `stream_ordering` of the last event that the user should see in the diff --git a/synapse/handlers/sliding_sync/store.py b/synapse/handlers/sliding_sync/store.py index d24fccf76f..7bcd5f27ea 100644 --- a/synapse/handlers/sliding_sync/store.py +++ b/synapse/handlers/sliding_sync/store.py @@ -13,7 +13,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import attr @@ -66,7 +66,7 @@ class SlidingSyncConnectionStore: async def get_and_clear_connection_positions( self, sync_config: SlidingSyncConfig, - from_token: Optional[SlidingSyncStreamToken], + from_token: SlidingSyncStreamToken | None, ) -> PerConnectionState: """Fetch the per-connection state for the token. @@ -93,7 +93,7 @@ class SlidingSyncConnectionStore: async def record_new_state( self, sync_config: SlidingSyncConfig, - from_token: Optional[SlidingSyncStreamToken], + from_token: SlidingSyncStreamToken | None, new_connection_state: MutablePerConnectionState, ) -> int: """Record updated per-connection state, returning the connection diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 641241287e..ebbe7afa84 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -30,7 +30,6 @@ from typing import ( Iterable, Mapping, NoReturn, - Optional, Protocol, ) from urllib.parse import urlencode @@ -102,12 +101,12 @@ class SsoIdentityProvider(Protocol): """User-facing name for this provider""" @property - def idp_icon(self) -> Optional[str]: + def idp_icon(self) -> str | None: """Optional MXC URI for user-facing icon""" return None @property - def idp_brand(self) -> Optional[str]: + def idp_brand(self) -> str | None: """Optional branding identifier""" return None @@ -115,8 +114,8 @@ class SsoIdentityProvider(Protocol): async def handle_redirect_request( self, request: SynapseRequest, - client_redirect_url: Optional[bytes], - ui_auth_session_id: Optional[str] = None, + client_redirect_url: bytes | None, + ui_auth_session_id: str | None = None, ) -> str: """Handle an incoming request to /login/sso/redirect @@ -141,10 +140,10 @@ class UserAttributes: # the localpart of the mxid that the mapper has assigned to the user. # if `None`, the mapper has not picked a userid, and the user should be prompted to # enter one. - localpart: Optional[str] + localpart: str | None confirm_localpart: bool = False - display_name: Optional[str] = None - picture: Optional[str] = None + display_name: str | None = None + picture: str | None = None # mypy thinks these are incompatible for some reason. emails: StrCollection = attr.Factory(list) @@ -157,19 +156,19 @@ class UsernameMappingSession: auth_provider_id: str # An optional session ID from the IdP. - auth_provider_session_id: Optional[str] + auth_provider_session_id: str | None # user ID on the IdP server remote_user_id: str # attributes returned by the ID mapper - display_name: Optional[str] + display_name: str | None emails: StrCollection - avatar_url: Optional[str] + avatar_url: str | None # An optional dictionary of extra attributes to be provided to the client in the # login response. - extra_login_attributes: Optional[JsonDict] + extra_login_attributes: JsonDict | None # where to redirect the client back to client_redirect_url: str @@ -178,11 +177,11 @@ class UsernameMappingSession: expiry_time_ms: int # choices made by the user - chosen_localpart: Optional[str] = None + chosen_localpart: str | None = None use_display_name: bool = True use_avatar: bool = True emails_to_use: StrCollection = () - terms_accepted_version: Optional[str] = None + terms_accepted_version: str | None = None # the HTTP cookie used to track the mapping session id @@ -278,7 +277,7 @@ class SsoHandler: self, request: Request, error: str, - error_description: Optional[str] = None, + error_description: str | None = None, code: int = 400, ) -> None: """Renders the error template and responds with it. @@ -302,7 +301,7 @@ class SsoHandler: self, request: SynapseRequest, client_redirect_url: bytes, - idp_id: Optional[str], + idp_id: str | None, ) -> str: """Handle a request to /login/sso/redirect @@ -321,7 +320,7 @@ class SsoHandler: ) # if the client chose an IdP, use that - idp: Optional[SsoIdentityProvider] = None + idp: SsoIdentityProvider | None = None if idp_id: idp = self._identity_providers.get(idp_id) if not idp: @@ -341,7 +340,7 @@ class SsoHandler: async def get_sso_user_by_remote_user_id( self, auth_provider_id: str, remote_user_id: str - ) -> Optional[str]: + ) -> str | None: """ Maps the user ID of a remote IdP to a mxid for a previously seen user. @@ -389,9 +388,9 @@ class SsoHandler: request: SynapseRequest, client_redirect_url: str, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], - grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], - extra_login_attributes: Optional[JsonDict] = None, - auth_provider_session_id: Optional[str] = None, + grandfather_existing_users: Callable[[], Awaitable[str | None]], + extra_login_attributes: JsonDict | None = None, + auth_provider_session_id: str | None = None, registration_enabled: bool = True, ) -> None: """ @@ -582,8 +581,8 @@ class SsoHandler: def _get_url_for_next_new_user_step( self, - attributes: Optional[UserAttributes] = None, - session: Optional[UsernameMappingSession] = None, + attributes: UserAttributes | None = None, + session: UsernameMappingSession | None = None, ) -> bytes: """Returns the URL to redirect to for the next step of new user registration @@ -622,8 +621,8 @@ class SsoHandler: attributes: UserAttributes, client_redirect_url: str, next_step_url: bytes, - extra_login_attributes: Optional[JsonDict], - auth_provider_session_id: Optional[str], + extra_login_attributes: JsonDict | None, + auth_provider_session_id: str | None, ) -> NoReturn: """Creates a UsernameMappingSession and redirects the browser @@ -1175,7 +1174,7 @@ class SsoHandler: self, auth_provider_id: str, auth_provider_session_id: str, - expected_user_id: Optional[str] = None, + expected_user_id: str | None = None, ) -> None: """Revoke any devices and in-flight logins tied to a provider session. diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index 2fbe407a63..db63f0483d 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -20,7 +20,7 @@ import logging from enum import Enum, auto -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: from synapse.server import HomeServer @@ -40,8 +40,8 @@ class StateDeltasHandler: async def _get_key_change( self, - prev_event_id: Optional[str], - event_id: Optional[str], + prev_event_id: str | None, + event_id: str | None, key_name: str, public_value: str, ) -> MatchChange: diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 0804f72c47..6d661453ac 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -26,7 +26,6 @@ from typing import ( Any, Counter as CounterType, Iterable, - Optional, ) from synapse.api.constants import EventContentFields, EventTypes, Membership @@ -62,7 +61,7 @@ class StatsHandler: self.stats_enabled = hs.config.stats.stats_enabled # The current position in the current_state_delta stream - self.pos: Optional[int] = None + self.pos: int | None = None # Guard to ensure we only process deltas one at a time self._is_processing = False diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a19b75203b..b534e24698 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -25,7 +25,6 @@ from typing import ( AbstractSet, Any, Mapping, - Optional, Sequence, ) @@ -116,7 +115,7 @@ class SyncConfig: user: UserID filter_collection: FilterCollection is_guest: bool - device_id: Optional[str] + device_id: str | None use_state_after: bool @@ -127,7 +126,7 @@ class TimelineBatch: limited: bool # A mapping of event ID to the bundled aggregations for the above events. # This is only calculated if limited is true. - bundled_aggregations: Optional[dict[str, BundledAggregations]] = None + bundled_aggregations: dict[str, BundledAggregations] | None = None def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -150,7 +149,7 @@ class JoinedSyncResult: account_data: list[JsonDict] unread_notifications: JsonDict unread_thread_notifications: JsonDict - summary: Optional[JsonDict] + summary: JsonDict | None unread_count: int def __bool__(self) -> bool: @@ -314,7 +313,7 @@ class SyncHandler: # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) self.lazy_loaded_members_cache: ExpiringCache[ - tuple[str, Optional[str]], LruCache[str, str] + tuple[str, str | None], LruCache[str, str] ] = ExpiringCache( cache_name="lazy_loaded_members_cache", server_name=self.server_name, @@ -331,7 +330,7 @@ class SyncHandler: requester: Requester, sync_config: SyncConfig, request_key: SyncRequestKey, - since_token: Optional[StreamToken] = None, + since_token: StreamToken | None = None, timeout: int = 0, full_state: bool = False, ) -> SyncResult: @@ -372,7 +371,7 @@ class SyncHandler: async def _wait_for_sync_for_user( self, sync_config: SyncConfig, - since_token: Optional[StreamToken], + since_token: StreamToken | None, timeout: int, full_state: bool, cache_context: ResponseCacheContext[SyncRequestKey], @@ -502,7 +501,7 @@ class SyncHandler: async def current_sync_for_user( self, sync_config: SyncConfig, - since_token: Optional[StreamToken] = None, + since_token: StreamToken | None = None, full_state: bool = False, ) -> SyncResult: """ @@ -537,7 +536,7 @@ class SyncHandler: self, sync_result_builder: "SyncResultBuilder", now_token: StreamToken, - since_token: Optional[StreamToken] = None, + since_token: StreamToken | None = None, ) -> tuple[StreamToken, dict[str, list[JsonDict]]]: """Get the ephemeral events for each room the user is in Args: @@ -604,8 +603,8 @@ class SyncHandler: sync_result_builder: "SyncResultBuilder", sync_config: SyncConfig, upto_token: StreamToken, - since_token: Optional[StreamToken] = None, - potential_recents: Optional[list[EventBase]] = None, + since_token: StreamToken | None = None, + potential_recents: list[EventBase] | None = None, newly_joined_room: bool = False, ) -> TimelineBatch: """Create a timeline batch for the room @@ -850,7 +849,7 @@ class SyncHandler: batch: TimelineBatch, state: MutableStateMap[EventBase], now_token: StreamToken, - ) -> Optional[JsonDict]: + ) -> JsonDict | None: """Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds @@ -963,11 +962,9 @@ class SyncHandler: return summary def get_lazy_loaded_members_cache( - self, cache_key: tuple[str, Optional[str]] + self, cache_key: tuple[str, str | None] ) -> LruCache[str, str]: - cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get( - cache_key - ) + cache: LruCache[str, str] | None = self.lazy_loaded_members_cache.get(cache_key) if cache is None: logger.debug("creating LruCache for %r", cache_key) cache = LruCache( @@ -985,7 +982,7 @@ class SyncHandler: room_id: str, batch: TimelineBatch, sync_config: SyncConfig, - since_token: Optional[StreamToken], + since_token: StreamToken | None, end_token: StreamToken, full_state: bool, joined: bool, @@ -1024,11 +1021,11 @@ class SyncHandler: ): # The memberships needed for events in the timeline. # Only calculated when `lazy_load_members` is on. - members_to_fetch: Optional[set[str]] = None + members_to_fetch: set[str] | None = None # A dictionary mapping user IDs to the first event in the timeline sent by # them. Only calculated when `lazy_load_members` is on. - first_event_by_sender_map: Optional[dict[str, EventBase]] = None + first_event_by_sender_map: dict[str, EventBase] | None = None # The contribution to the room state from state events in the timeline. # Only contains the last event for any given state key. @@ -1172,7 +1169,7 @@ class SyncHandler: sync_config: SyncConfig, batch: TimelineBatch, end_token: StreamToken, - members_to_fetch: Optional[set[str]], + members_to_fetch: set[str] | None, timeline_state: StateMap[str], joined: bool, ) -> StateMap[str]: @@ -1322,7 +1319,7 @@ class SyncHandler: batch: TimelineBatch, since_token: StreamToken, end_token: StreamToken, - members_to_fetch: Optional[set[str]], + members_to_fetch: set[str] | None, timeline_state: StateMap[str], ) -> StateMap[str]: """Calculate the state events to be included in an incremental sync response. @@ -1649,7 +1646,7 @@ class SyncHandler: async def generate_sync_result( self, sync_config: SyncConfig, - since_token: Optional[StreamToken] = None, + since_token: StreamToken | None = None, full_state: bool = False, ) -> SyncResult: """Generates the response body of a sync result. @@ -1804,7 +1801,7 @@ class SyncHandler: async def get_sync_result_builder( self, sync_config: SyncConfig, - since_token: Optional[StreamToken] = None, + since_token: StreamToken | None = None, full_state: bool = False, ) -> "SyncResultBuilder": """ @@ -2439,7 +2436,7 @@ class SyncHandler: # This is all screaming out for a refactor, as the logic here is # subtle and the moving parts numerous. if leave_event.internal_metadata.is_out_of_band_membership(): - batch_events: Optional[list[EventBase]] = [leave_event] + batch_events: list[EventBase] | None = [leave_event] else: batch_events = None @@ -2608,7 +2605,7 @@ class SyncHandler: sync_result_builder: "SyncResultBuilder", room_builder: "RoomSyncResultBuilder", ephemeral: list[JsonDict], - tags: Optional[Mapping[str, JsonMapping]], + tags: Mapping[str, JsonMapping] | None, account_data: Mapping[str, JsonMapping], always_include: bool = False, ) -> None: @@ -2758,7 +2755,7 @@ class SyncHandler: # An out of band room won't have any state changes. state = {} - summary: Optional[JsonDict] = {} + summary: JsonDict | None = {} # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form @@ -3007,7 +3004,7 @@ class SyncResultBuilder: sync_config: SyncConfig full_state: bool - since_token: Optional[StreamToken] + since_token: StreamToken | None now_token: StreamToken joined_room_ids: frozenset[str] excluded_room_ids: frozenset[str] @@ -3100,10 +3097,10 @@ class RoomSyncResultBuilder: room_id: str rtype: str - events: Optional[list[EventBase]] + events: list[EventBase] | None newly_joined: bool full_state: bool - since_token: Optional[StreamToken] + since_token: StreamToken | None upto_token: StreamToken end_token: StreamToken out_of_band: bool = False diff --git a/synapse/handlers/thread_subscriptions.py b/synapse/handlers/thread_subscriptions.py index d56c915e0a..539672c7fe 100644 --- a/synapse/handlers/thread_subscriptions.py +++ b/synapse/handlers/thread_subscriptions.py @@ -1,6 +1,6 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.constants import RelationTypes from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError @@ -29,7 +29,7 @@ class ThreadSubscriptionsHandler: user_id: UserID, room_id: str, thread_root_event_id: str, - ) -> Optional[ThreadSubscription]: + ) -> ThreadSubscription | None: """Get thread subscription settings for a specific thread and user. Checks that the thread root is both a real event and also that it is visible to the user. @@ -62,8 +62,8 @@ class ThreadSubscriptionsHandler: room_id: str, thread_root_event_id: str, *, - automatic_event_id: Optional[str], - ) -> Optional[int]: + automatic_event_id: str | None, + ) -> int | None: """Sets or updates a user's subscription settings for a specific thread root. Args: @@ -146,7 +146,7 @@ class ThreadSubscriptionsHandler: async def unsubscribe_user_from_thread( self, user_id: UserID, room_id: str, thread_root_event_id: str - ) -> Optional[int]: + ) -> int | None: """Clears a user's subscription settings for a specific thread root. Args: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 17e43858c9..8b577d5d58 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -20,7 +20,7 @@ # import logging import random -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable import attr @@ -576,8 +576,8 @@ class TypingNotificationEventSource(EventSource[int, JsonMapping]): limit: int, room_ids: Iterable[str], is_guest: bool, - explicit_room_id: Optional[str] = None, - to_key: Optional[int] = None, + explicit_room_id: str | None = None, + to_key: int | None = None, ) -> tuple[list[JsonMapping], int]: """ Find typing notifications for given rooms (> `from_token` and <= `to_token`) diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index cbae33eaec..a0097dbc96 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -136,7 +136,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): except PartialDownloadError as pde: # Twisted is silly data = pde.response - # For mypy's benefit. A general Error.response is Optional[bytes], but + # For mypy's benefit. A general Error.response is bytes | None, but # a PartialDownloadError.response should be bytes AFAICS. assert data is not None resp_body = json_decoder.decode(data.decode("utf-8")) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index fd05aff4c8..e5210a3e97 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -21,7 +21,7 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.internet.interfaces import IDelayedCall @@ -116,7 +116,7 @@ class UserDirectoryHandler(StateDeltasHandler): self._hs = hs # The current position in the current_state_delta stream - self.pos: Optional[int] = None + self.pos: int | None = None # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -124,7 +124,7 @@ class UserDirectoryHandler(StateDeltasHandler): # Guard to ensure we only have one process for refreshing remote profiles self._is_refreshing_remote_profiles = False # Handle to cancel the `call_later` of `kick_off_remote_profile_refresh_process` - self._refresh_remote_profiles_call_later: Optional[IDelayedCall] = None + self._refresh_remote_profiles_call_later: IDelayedCall | None = None # Guard to ensure we only have one process for refreshing remote profiles # for the given servers. @@ -299,8 +299,8 @@ class UserDirectoryHandler(StateDeltasHandler): async def _handle_room_publicity_change( self, room_id: str, - prev_event_id: Optional[str], - event_id: Optional[str], + prev_event_id: str | None, + event_id: str | None, typ: str, ) -> None: """Handle a room having potentially changed from/to world_readable/publicly @@ -372,8 +372,8 @@ class UserDirectoryHandler(StateDeltasHandler): async def _handle_room_membership_event( self, room_id: str, - prev_event_id: Optional[str], - event_id: Optional[str], + prev_event_id: str | None, + event_id: str | None, state_key: str, ) -> None: """Process a single room membershp event. @@ -519,7 +519,7 @@ class UserDirectoryHandler(StateDeltasHandler): self, user_id: str, room_id: str, - prev_event_id: Optional[str], + prev_event_id: str | None, event_id: str, ) -> None: """Check member event changes for any profile changes and update the diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py index af5498c560..3e097d21f2 100644 --- a/synapse/handlers/worker_lock.py +++ b/synapse/handlers/worker_lock.py @@ -26,8 +26,6 @@ from typing import ( TYPE_CHECKING, AsyncContextManager, Collection, - Optional, - Union, ) from weakref import WeakSet @@ -72,9 +70,7 @@ class WorkerLocksHandler: # Map from lock name/key to set of `WaitingLock` that are active for # that lock. - self._locks: dict[ - tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]] - ] = {} + self._locks: dict[tuple[str, str], WeakSet[WaitingLock | WaitingMultiLock]] = {} self._clock.looping_call(self._cleanup_locks, 30_000) @@ -185,7 +181,7 @@ class WorkerLocksHandler: return def _wake_all_locks( - locks: Collection[Union[WaitingLock, WaitingMultiLock]], + locks: Collection[WaitingLock | WaitingMultiLock], ) -> None: for lock in locks: deferred = lock.deferred @@ -211,9 +207,9 @@ class WaitingLock: handler: WorkerLocksHandler lock_name: str lock_key: str - write: Optional[bool] + write: bool | None deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred) - _inner_lock: Optional[Lock] = None + _inner_lock: Lock | None = None _retry_interval: float = 0.1 _lock_span: "opentracing.Scope" = attr.Factory( lambda: start_active_span("WaitingLock.lock") @@ -258,10 +254,10 @@ class WaitingLock: async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool | None: assert self._inner_lock self.handler.notify_lock_released(self.lock_name, self.lock_key) @@ -296,7 +292,7 @@ class WaitingMultiLock: deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred) - _inner_lock_cm: Optional[AsyncContextManager] = None + _inner_lock_cm: AsyncContextManager | None = None _retry_interval: float = 0.1 _lock_span: "opentracing.Scope" = attr.Factory( lambda: start_active_span("WaitingLock.lock") @@ -338,10 +334,10 @@ class WaitingMultiLock: async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool | None: assert self._inner_lock_cm for lock_name, lock_key in self.lock_names: diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 272bbc05f9..f13271f302 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -19,7 +19,6 @@ # # import re -from typing import Union from twisted.internet import address, task from twisted.web.client import FileBodyProducer @@ -75,7 +74,7 @@ def _get_requested_host(request: IRequest) -> bytes: return hostname # no Host header, use the address/port that the request arrived on - host: Union[address.IPv4Address, address.IPv6Address] = request.getHost() + host: address.IPv4Address | address.IPv6Address = request.getHost() hostname = host.host.encode("ascii") diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 1a17b8461f..3661a2aeb7 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -18,7 +18,7 @@ # # -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable from twisted.web.server import Request @@ -41,7 +41,7 @@ class AdditionalResource(DirectServeJsonResource): def __init__( self, hs: "HomeServer", - handler: Callable[[Request], Awaitable[Optional[tuple[int, Any]]]], + handler: Callable[[Request], Awaitable[tuple[int, Any] | None]], ): """Initialise AdditionalResource @@ -56,7 +56,7 @@ class AdditionalResource(DirectServeJsonResource): super().__init__(clock=hs.get_clock()) self._handler = handler - async def _async_render(self, request: Request) -> Optional[tuple[int, Any]]: + async def _async_render(self, request: Request) -> tuple[int, Any] | None: # Cheekily pass the result straight through, so we don't need to worry # if its an awaitable or not. return await self._handler(request) diff --git a/synapse/http/client.py b/synapse/http/client.py index ff1f7c7128..9971accccd 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -28,9 +28,7 @@ from typing import ( BinaryIO, Callable, Mapping, - Optional, Protocol, - Union, ) import attr @@ -118,7 +116,7 @@ incoming_responses_counter = Counter( # the type of the headers map, to be passed to the t.w.h.Headers. # # The actual type accepted by Twisted is -# Mapping[Union[str, bytes], Sequence[Union[str, bytes]] , +# Mapping[str | bytes], Sequence[str | bytes] , # allowing us to mix and match str and bytes freely. However: any str is also a # Sequence[str]; passing a header string value which is a # standalone str is interpreted as a sequence of 1-codepoint strings. This is a disastrous footgun. @@ -126,21 +124,21 @@ incoming_responses_counter = Counter( # # We also simplify the keys to be either all str or all bytes. This helps because # Dict[K, V] is invariant in K (and indeed V). -RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]] +RawHeaders = Mapping[str, "RawHeaderValue"] | Mapping[bytes, "RawHeaderValue"] # the value actually has to be a List, but List is invariant so we can't specify that # the entries can either be Lists or bytes. -RawHeaderValue = Union[ - StrSequence, - list[bytes], - list[Union[str, bytes]], - tuple[bytes, ...], - tuple[Union[str, bytes], ...], -] +RawHeaderValue = ( + StrSequence + | list[bytes] + | list[str | bytes] + | tuple[bytes, ...] + | tuple[str | bytes, ...] +) def _is_ip_blocked( - ip_address: IPAddress, allowlist: Optional[IPSet], blocklist: IPSet + ip_address: IPAddress, allowlist: IPSet | None, blocklist: IPSet ) -> bool: """ Compares an IP address to allowed and disallowed IP sets. @@ -186,7 +184,7 @@ class _IPBlockingResolver: def __init__( self, reactor: IReactorPluggableNameResolver, - ip_allowlist: Optional[IPSet], + ip_allowlist: IPSet | None, ip_blocklist: IPSet, ): """ @@ -262,7 +260,7 @@ class BlocklistingReactorWrapper: def __init__( self, reactor: IReactorPluggableNameResolver, - ip_allowlist: Optional[IPSet], + ip_allowlist: IPSet | None, ip_blocklist: IPSet, ): self._reactor = reactor @@ -291,7 +289,7 @@ class BlocklistingAgentWrapper(Agent): self, agent: IAgent, ip_blocklist: IPSet, - ip_allowlist: Optional[IPSet] = None, + ip_allowlist: IPSet | None = None, ): """ Args: @@ -307,13 +305,13 @@ class BlocklistingAgentWrapper(Agent): self, method: bytes, uri: bytes, - headers: Optional[Headers] = None, - bodyProducer: Optional[IBodyProducer] = None, + headers: Headers | None = None, + bodyProducer: IBodyProducer | None = None, ) -> defer.Deferred: h = urllib.parse.urlparse(uri.decode("ascii")) try: - # h.hostname is Optional[str], None raises an AddrFormatError, so + # h.hostname is str | None, None raises an AddrFormatError, so # this is safe even though IPAddress requires a str. ip_address = IPAddress(h.hostname) # type: ignore[arg-type] except AddrFormatError: @@ -346,7 +344,7 @@ class BaseHttpClient: def __init__( self, hs: "HomeServer", - treq_args: Optional[dict[str, Any]] = None, + treq_args: dict[str, Any] | None = None, ): self.hs = hs self.server_name = hs.hostname @@ -371,8 +369,8 @@ class BaseHttpClient: self, method: str, uri: str, - data: Optional[bytes] = None, - headers: Optional[Headers] = None, + data: bytes | None = None, + headers: Headers | None = None, ) -> IResponse: """ Args: @@ -476,8 +474,8 @@ class BaseHttpClient: async def post_urlencoded_get_json( self, uri: str, - args: Optional[Mapping[str, Union[str, list[str]]]] = None, - headers: Optional[RawHeaders] = None, + args: Mapping[str, str | list[str]] | None = None, + headers: RawHeaders | None = None, ) -> Any: """ Args: @@ -525,7 +523,7 @@ class BaseHttpClient: ) async def post_json_get_json( - self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None + self, uri: str, post_json: Any, headers: RawHeaders | None = None ) -> Any: """ @@ -574,8 +572,8 @@ class BaseHttpClient: async def get_json( self, uri: str, - args: Optional[QueryParams] = None, - headers: Optional[RawHeaders] = None, + args: QueryParams | None = None, + headers: RawHeaders | None = None, ) -> Any: """Gets some json from the given URI. @@ -605,8 +603,8 @@ class BaseHttpClient: self, uri: str, json_body: Any, - args: Optional[QueryParams] = None, - headers: Optional[RawHeaders] = None, + args: QueryParams | None = None, + headers: RawHeaders | None = None, ) -> Any: """Puts some json to the given URI. @@ -656,8 +654,8 @@ class BaseHttpClient: async def get_raw( self, uri: str, - args: Optional[QueryParams] = None, - headers: Optional[RawHeaders] = None, + args: QueryParams | None = None, + headers: RawHeaders | None = None, ) -> bytes: """Gets raw text from the given URI. @@ -701,9 +699,9 @@ class BaseHttpClient: self, url: str, output_stream: BinaryIO, - max_size: Optional[int] = None, - headers: Optional[RawHeaders] = None, - is_allowed_content_type: Optional[Callable[[str], bool]] = None, + max_size: int | None = None, + headers: RawHeaders | None = None, + is_allowed_content_type: Callable[[str], bool] | None = None, ) -> tuple[int, dict[bytes, list[bytes]], str, int]: """GETs a file from a given URL Args: @@ -812,9 +810,9 @@ class SimpleHttpClient(BaseHttpClient): def __init__( self, hs: "HomeServer", - treq_args: Optional[dict[str, Any]] = None, - ip_allowlist: Optional[IPSet] = None, - ip_blocklist: Optional[IPSet] = None, + treq_args: dict[str, Any] | None = None, + ip_allowlist: IPSet | None = None, + ip_blocklist: IPSet | None = None, use_proxy: bool = False, ): super().__init__(hs, treq_args=treq_args) @@ -891,8 +889,8 @@ class ReplicationClient(BaseHttpClient): self, method: str, uri: str, - data: Optional[bytes] = None, - headers: Optional[Headers] = None, + data: bytes | None = None, + headers: Headers | None = None, ) -> IResponse: """ Make a request, differs from BaseHttpClient.request in that it does not use treq. @@ -1028,7 +1026,7 @@ class BodyExceededMaxSize(Exception): class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which immediately errors upon receiving data.""" - transport: Optional[ITCPTransport] = None + transport: ITCPTransport | None = None def __init__(self, deferred: defer.Deferred): self.deferred = deferred @@ -1058,10 +1056,10 @@ class MultipartResponse: """ json: bytes = b"{}" - length: Optional[int] = None - content_type: Optional[bytes] = None - disposition: Optional[bytes] = None - url: Optional[bytes] = None + length: int | None = None + content_type: bytes | None = None + disposition: bytes | None = None + url: bytes | None = None class _MultipartParserProtocol(protocol.Protocol): @@ -1069,20 +1067,20 @@ class _MultipartParserProtocol(protocol.Protocol): Protocol to read and parse a MSC3916 multipart/mixed response """ - transport: Optional[ITCPTransport] = None + transport: ITCPTransport | None = None def __init__( self, stream: ByteWriteable, deferred: defer.Deferred, boundary: str, - max_length: Optional[int], + max_length: int | None, ) -> None: self.stream = stream self.deferred = deferred self.boundary = boundary self.max_length = max_length - self.parser: Optional[MultipartParser] = None + self.parser: MultipartParser | None = None self.multipart_response = MultipartResponse() self.has_redirect = False self.in_json = False @@ -1177,10 +1175,10 @@ class _MultipartParserProtocol(protocol.Protocol): class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" - transport: Optional[ITCPTransport] = None + transport: ITCPTransport | None = None def __init__( - self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int] + self, stream: ByteWriteable, deferred: defer.Deferred, max_size: int | None ): self.stream = stream self.deferred = deferred @@ -1230,7 +1228,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): def read_body_with_max_size( - response: IResponse, stream: ByteWriteable, max_size: Optional[int] + response: IResponse, stream: ByteWriteable, max_size: int | None ) -> "defer.Deferred[int]": """ Read a HTTP response body to a file-object. Optionally enforcing a maximum file size. @@ -1260,7 +1258,7 @@ def read_body_with_max_size( def read_multipart_response( - response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int] + response: IResponse, stream: ByteWriteable, boundary: str, max_length: int | None ) -> "defer.Deferred[MultipartResponse]": """ Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into @@ -1285,7 +1283,7 @@ def read_multipart_response( return d -def encode_query_args(args: Optional[QueryParams]) -> bytes: +def encode_query_args(args: QueryParams | None) -> bytes: """ Encodes a map of query arguments to bytes which can be appended to a URL. @@ -1323,7 +1321,7 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): def is_unknown_endpoint( - e: HttpResponseException, synapse_error: Optional[SynapseError] = None + e: HttpResponseException, synapse_error: SynapseError | None = None ) -> bool: """ Returns true if the response was due to an endpoint being unimplemented. diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index db803bc75a..094655f91a 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -22,7 +22,7 @@ import abc import base64 import logging -from typing import Optional, Union +from typing import Union import attr from zope.interface import implementer @@ -106,7 +106,7 @@ class HTTPConnectProxyEndpoint: proxy_endpoint: IStreamClientEndpoint, host: bytes, port: int, - proxy_creds: Optional[ProxyCredentials], + proxy_creds: ProxyCredentials | None, ): self._reactor = reactor self._proxy_endpoint = proxy_endpoint @@ -146,7 +146,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): dst_host: bytes, dst_port: int, wrapped_factory: IProtocolFactory, - proxy_creds: Optional[ProxyCredentials], + proxy_creds: ProxyCredentials | None, ): self.dst_host = dst_host self.dst_port = dst_port @@ -212,7 +212,7 @@ class HTTPConnectProtocol(protocol.Protocol): port: int, wrapped_protocol: IProtocol, connected_deferred: defer.Deferred, - proxy_creds: Optional[ProxyCredentials], + proxy_creds: ProxyCredentials | None, ): self.host = host self.port = port @@ -275,7 +275,7 @@ class HTTPConnectSetupClient(http.HTTPClient): self, host: bytes, port: int, - proxy_creds: Optional[ProxyCredentials], + proxy_creds: ProxyCredentials | None, ): self.host = host self.port = port diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index f8482d9c48..c3ba26fe03 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -19,7 +19,7 @@ # import logging import urllib.parse -from typing import Any, Generator, Optional +from typing import Any, Generator from urllib.request import ( # type: ignore[attr-defined] proxy_bypass_environment, ) @@ -101,13 +101,13 @@ class MatrixFederationAgent: server_name: str, reactor: ISynapseReactor, clock: Clock, - tls_client_options_factory: Optional[FederationPolicyForHTTPS], + tls_client_options_factory: FederationPolicyForHTTPS | None, user_agent: bytes, - ip_allowlist: Optional[IPSet], + ip_allowlist: IPSet | None, ip_blocklist: IPSet, - proxy_config: Optional[ProxyConfig] = None, - _srv_resolver: Optional[SrvResolver] = None, - _well_known_resolver: Optional[WellKnownResolver] = None, + proxy_config: ProxyConfig | None = None, + _srv_resolver: SrvResolver | None = None, + _well_known_resolver: WellKnownResolver | None = None, ): """ Args: @@ -172,8 +172,8 @@ class MatrixFederationAgent: self, method: bytes, uri: bytes, - headers: Optional[Headers] = None, - bodyProducer: Optional[IBodyProducer] = None, + headers: Headers | None = None, + bodyProducer: IBodyProducer | None = None, ) -> Generator[defer.Deferred, Any, IResponse]: """ Args: @@ -259,9 +259,9 @@ class MatrixHostnameEndpointFactory: *, reactor: IReactorCore, proxy_reactor: IReactorCore, - tls_client_options_factory: Optional[FederationPolicyForHTTPS], - srv_resolver: Optional[SrvResolver], - proxy_config: Optional[ProxyConfig], + tls_client_options_factory: FederationPolicyForHTTPS | None, + srv_resolver: SrvResolver | None, + proxy_config: ProxyConfig | None, ): self._reactor = reactor self._proxy_reactor = proxy_reactor @@ -310,9 +310,9 @@ class MatrixHostnameEndpoint: *, reactor: IReactorCore, proxy_reactor: IReactorCore, - tls_client_options_factory: Optional[FederationPolicyForHTTPS], + tls_client_options_factory: FederationPolicyForHTTPS | None, srv_resolver: SrvResolver, - proxy_config: Optional[ProxyConfig], + proxy_config: ProxyConfig | None, parsed_uri: URI, ): self._reactor = reactor diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index ac4d954c2c..ec72e178c9 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -22,7 +22,7 @@ import logging import random import time from io import BytesIO -from typing import Callable, Optional +from typing import Callable import attr @@ -80,7 +80,7 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) class WellKnownLookupResult: - delegated_server: Optional[bytes] + delegated_server: bytes | None class WellKnownResolver: @@ -93,8 +93,8 @@ class WellKnownResolver: clock: Clock, agent: IAgent, user_agent: bytes, - well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None, - had_well_known_cache: Optional[TTLCache[bytes, bool]] = None, + well_known_cache: TTLCache[bytes, bytes | None] | None = None, + had_well_known_cache: TTLCache[bytes, bool] | None = None, ): """ Args: @@ -156,7 +156,7 @@ class WellKnownResolver: # label metrics) server_name=self.server_name, ): - result: Optional[bytes] + result: bytes | None cache_period: float result, cache_period = await self._fetch_well_known(server_name) @@ -320,7 +320,7 @@ class WellKnownResolver: def _cache_period_from_headers( headers: Headers, time_now: Callable[[], float] = time.time -) -> Optional[float]: +) -> float | None: cache_controls = _parse_cache_control(headers) if b"no-store" in cache_controls: @@ -348,7 +348,7 @@ def _cache_period_from_headers( return None -def _parse_cache_control(headers: Headers) -> dict[bytes, Optional[bytes]]: +def _parse_cache_control(headers: Headers) -> dict[bytes, bytes | None]: cache_controls = {} cache_control_headers = headers.getRawHeaders(b"cache-control") or [] for hdr in cache_control_headers: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 562007c74f..7090960cfb 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -33,10 +33,8 @@ from typing import ( Callable, Generic, Literal, - Optional, TextIO, TypeVar, - Union, cast, overload, ) @@ -153,15 +151,15 @@ class MatrixFederationRequest: """The remote server to send the HTTP request to. """ - json: Optional[JsonDict] = None + json: JsonDict | None = None """JSON to send in the body. """ - json_callback: Optional[Callable[[], JsonDict]] = None + json_callback: Callable[[], JsonDict] | None = None """A callback to generate the JSON. """ - query: Optional[QueryParams] = None + query: QueryParams | None = None """Query arguments. """ @@ -204,7 +202,7 @@ class MatrixFederationRequest: ) object.__setattr__(self, "uri", uri) - def get_json(self) -> Optional[JsonDict]: + def get_json(self) -> JsonDict | None: if self.json_callback: return self.json_callback() return self.json @@ -216,7 +214,7 @@ class _BaseJsonParser(ByteParser[T]): CONTENT_TYPE = "application/json" def __init__( - self, validator: Optional[Callable[[Optional[object]], bool]] = None + self, validator: Callable[[object | None], bool] | None = None ) -> None: """ Args: @@ -390,7 +388,7 @@ class BinaryIOWrapper: self.decoder = codecs.getincrementaldecoder(encoding)(errors) self.file = file - def write(self, b: Union[bytes, bytearray]) -> int: + def write(self, b: bytes | bytearray) -> int: self.file.write(self.decoder.decode(b)) return len(b) @@ -407,7 +405,7 @@ class MatrixFederationHttpClient: def __init__( self, hs: "HomeServer", - tls_client_options_factory: Optional[FederationPolicyForHTTPS], + tls_client_options_factory: FederationPolicyForHTTPS | None, ): self.hs = hs self.signing_key = hs.signing_key @@ -550,7 +548,7 @@ class MatrixFederationHttpClient: self, request: MatrixFederationRequest, retry_on_dns_fail: bool = True, - timeout: Optional[int] = None, + timeout: int | None = None, long_retries: bool = False, ignore_backoff: bool = False, backoff_on_404: bool = False, @@ -693,7 +691,7 @@ class MatrixFederationHttpClient: destination_bytes, method_bytes, url_to_sign_bytes, json ) data = encode_canonical_json(json) - producer: Optional[IBodyProducer] = QuieterFileBodyProducer( + producer: IBodyProducer | None = QuieterFileBodyProducer( BytesIO(data), cooperator=self._cooperator ) else: @@ -905,11 +903,11 @@ class MatrixFederationHttpClient: def build_auth_headers( self, - destination: Optional[bytes], + destination: bytes | None, method: bytes, url_bytes: bytes, - content: Optional[JsonDict] = None, - destination_is: Optional[bytes] = None, + content: JsonDict | None = None, + destination_is: bytes | None = None, ) -> list[bytes]: """ Builds the Authorization headers for a federation request @@ -970,11 +968,11 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - args: Optional[QueryParams] = None, - data: Optional[JsonDict] = None, - json_data_callback: Optional[Callable[[], JsonDict]] = None, + args: QueryParams | None = None, + data: JsonDict | None = None, + json_data_callback: Callable[[], JsonDict] | None = None, long_retries: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, @@ -987,15 +985,15 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - args: Optional[QueryParams] = None, - data: Optional[JsonDict] = None, - json_data_callback: Optional[Callable[[], JsonDict]] = None, + args: QueryParams | None = None, + data: JsonDict | None = None, + json_data_callback: Callable[[], JsonDict] | None = None, long_retries: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser[T]] = None, + parser: ByteParser[T] | None = None, backoff_on_all_error_codes: bool = False, ) -> T: ... @@ -1003,17 +1001,17 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - args: Optional[QueryParams] = None, - data: Optional[JsonDict] = None, - json_data_callback: Optional[Callable[[], JsonDict]] = None, + args: QueryParams | None = None, + data: JsonDict | None = None, + json_data_callback: Callable[[], JsonDict] | None = None, long_retries: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser[T]] = None, + parser: ByteParser[T] | None = None, backoff_on_all_error_codes: bool = False, - ) -> Union[JsonDict, T]: + ) -> JsonDict | T: """Sends the specified json data using PUT Args: @@ -1109,11 +1107,11 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - data: Optional[JsonDict] = None, + data: JsonDict | None = None, long_retries: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, ) -> JsonDict: """Sends the specified json data using POST @@ -1188,9 +1186,9 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, @@ -1201,9 +1199,9 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - args: Optional[QueryParams] = ..., + args: QueryParams | None = ..., retry_on_dns_fail: bool = ..., - timeout: Optional[int] = ..., + timeout: int | None = ..., ignore_backoff: bool = ..., try_trailing_slash_on_400: bool = ..., parser: ByteParser[T] = ..., @@ -1213,13 +1211,13 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser[T]] = None, - ) -> Union[JsonDict, T]: + parser: ByteParser[T] | None = None, + ) -> JsonDict | T: """GETs some json from the given host homeserver and path Args: @@ -1282,9 +1280,9 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, @@ -1295,9 +1293,9 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - args: Optional[QueryParams] = ..., + args: QueryParams | None = ..., retry_on_dns_fail: bool = ..., - timeout: Optional[int] = ..., + timeout: int | None = ..., ignore_backoff: bool = ..., try_trailing_slash_on_400: bool = ..., parser: ByteParser[T] = ..., @@ -1307,13 +1305,13 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser[T]] = None, - ) -> tuple[Union[JsonDict, T], dict[bytes, list[bytes]]]: + parser: ByteParser[T] | None = None, + ) -> tuple[JsonDict | T, dict[bytes, list[bytes]]]: """GETs some json from the given host homeserver and path Args: @@ -1401,9 +1399,9 @@ class MatrixFederationHttpClient: destination: str, path: str, long_retries: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, ) -> JsonDict: """Send a DELETE request to the remote expecting some json response @@ -1477,7 +1475,7 @@ class MatrixFederationHttpClient: download_ratelimiter: Ratelimiter, ip_address: str, max_size: int, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, ignore_backoff: bool = False, follow_redirects: bool = False, @@ -1639,7 +1637,7 @@ class MatrixFederationHttpClient: download_ratelimiter: Ratelimiter, ip_address: str, max_size: int, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, ignore_backoff: bool = False, ) -> tuple[int, dict[bytes, list[bytes]], bytes]: diff --git a/synapse/http/proxy.py b/synapse/http/proxy.py index 583dd092bd..c7f5e39dd8 100644 --- a/synapse/http/proxy.py +++ b/synapse/http/proxy.py @@ -22,7 +22,7 @@ import json import logging import urllib.parse -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from twisted.internet import protocol from twisted.internet.interfaces import ITCPTransport @@ -65,7 +65,7 @@ assert all(header.lower() == header for header in HOP_BY_HOP_HEADERS_LOWERCASE) def parse_connection_header_value( - connection_header_value: Optional[bytes], + connection_header_value: bytes | None, ) -> set[str]: """ Parse the `Connection` header to determine which headers we should not be copied @@ -237,7 +237,7 @@ class _ProxyResponseBody(protocol.Protocol): request. """ - transport: Optional[ITCPTransport] = None + transport: ITCPTransport | None = None def __init__(self, request: "SynapseRequest") -> None: self._request = request diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 67e04b18d9..d315ce8475 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -21,7 +21,7 @@ import logging import random import re -from typing import Any, Collection, Optional, Sequence, Union, cast +from typing import Any, Collection, Sequence, cast from urllib.parse import urlparse from urllib.request import ( # type: ignore[attr-defined] proxy_bypass_environment, @@ -119,14 +119,14 @@ class ProxyAgent(_AgentBase): self, *, reactor: IReactorCore, - proxy_reactor: Optional[IReactorCore] = None, - contextFactory: Optional[IPolicyForHTTPS] = None, - connectTimeout: Optional[float] = None, - bindAddress: Optional[bytes] = None, - pool: Optional[HTTPConnectionPool] = None, - proxy_config: Optional[ProxyConfig] = None, + proxy_reactor: IReactorCore | None = None, + contextFactory: IPolicyForHTTPS | None = None, + connectTimeout: float | None = None, + bindAddress: bytes | None = None, + pool: HTTPConnectionPool | None = None, + proxy_config: ProxyConfig | None = None, federation_proxy_locations: Collection[InstanceLocationConfig] = (), - federation_proxy_credentials: Optional[ProxyCredentials] = None, + federation_proxy_credentials: ProxyCredentials | None = None, ): contextFactory = contextFactory or BrowserLikePolicyForHTTPS() @@ -175,8 +175,8 @@ class ProxyAgent(_AgentBase): self._policy_for_https = contextFactory self._reactor = cast(IReactorTime, reactor) - self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None - self._federation_proxy_credentials: Optional[ProxyCredentials] = None + self._federation_proxy_endpoint: IStreamClientEndpoint | None = None + self._federation_proxy_credentials: ProxyCredentials | None = None if federation_proxy_locations: assert federation_proxy_credentials is not None, ( "`federation_proxy_credentials` are required when using `federation_proxy_locations`" @@ -220,8 +220,8 @@ class ProxyAgent(_AgentBase): self, method: bytes, uri: bytes, - headers: Optional[Headers] = None, - bodyProducer: Optional[IBodyProducer] = None, + headers: Headers | None = None, + bodyProducer: IBodyProducer | None = None, ) -> "defer.Deferred[IResponse]": """ Issue a request to the server indicated by the given uri. @@ -363,13 +363,13 @@ class ProxyAgent(_AgentBase): def http_proxy_endpoint( - proxy: Optional[bytes], + proxy: bytes | None, reactor: IReactorCore, - tls_options_factory: Optional[IPolicyForHTTPS], + tls_options_factory: IPolicyForHTTPS | None, timeout: float = 30, - bindAddress: Optional[Union[bytes, str, tuple[Union[bytes, str], int]]] = None, - attemptDelay: Optional[float] = None, -) -> tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: + bindAddress: bytes | str | tuple[bytes | str, int] | None = None, + attemptDelay: float | None = None, +) -> tuple[IStreamClientEndpoint | None, ProxyCredentials | None]: """Parses an http proxy setting and returns an endpoint for the proxy Args: @@ -418,7 +418,7 @@ def http_proxy_endpoint( def parse_proxy( proxy: bytes, default_scheme: bytes = b"http", default_port: int = 1080 -) -> tuple[bytes, bytes, int, Optional[ProxyCredentials]]: +) -> tuple[bytes, bytes, int, ProxyCredentials | None]: """ Parse a proxy connection string. diff --git a/synapse/http/replicationagent.py b/synapse/http/replicationagent.py index f4799bd1b2..708e4c386b 100644 --- a/synapse/http/replicationagent.py +++ b/synapse/http/replicationagent.py @@ -20,7 +20,6 @@ # import logging -from typing import Optional from zope.interface import implementer @@ -119,9 +118,9 @@ class ReplicationAgent(_AgentBase): reactor: ISynapseReactor, instance_map: dict[str, InstanceLocationConfig], contextFactory: IPolicyForHTTPS, - connectTimeout: Optional[float] = None, - bindAddress: Optional[bytes] = None, - pool: Optional[HTTPConnectionPool] = None, + connectTimeout: float | None = None, + bindAddress: bytes | None = None, + pool: HTTPConnectionPool | None = None, ): """ Create a ReplicationAgent. @@ -149,8 +148,8 @@ class ReplicationAgent(_AgentBase): self, method: bytes, uri: bytes, - headers: Optional[Headers] = None, - bodyProducer: Optional[IBodyProducer] = None, + headers: Headers | None = None, + bodyProducer: IBodyProducer | None = None, ) -> "defer.Deferred[IResponse]": """ Issue a request to the server indicated by the given uri. diff --git a/synapse/http/server.py b/synapse/http/server.py index 1f4728fba2..5f4e7484fd 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -35,10 +35,8 @@ from typing import ( Callable, Iterable, Iterator, - Optional, Pattern, Protocol, - Union, cast, ) @@ -111,7 +109,7 @@ HTTP_STATUS_REQUEST_CANCELLED = 499 def return_json_error( - f: failure.Failure, request: "SynapseRequest", config: Optional[HomeServerConfig] + f: failure.Failure, request: "SynapseRequest", config: HomeServerConfig | None ) -> None: """Sends a JSON error response to clients.""" @@ -173,7 +171,7 @@ def return_json_error( def return_html_error( f: failure.Failure, request: Request, - error_template: Union[str, jinja2.Template], + error_template: str | jinja2.Template, ) -> None: """Sends an HTML error page corresponding to the given failure. @@ -264,7 +262,7 @@ def wrap_async_request_handler( # it is actually called with a SynapseRequest and a kwargs dict for the params, # but I can't figure out how to represent that. ServletCallback = Callable[ - ..., Union[None, Awaitable[None], tuple[int, Any], Awaitable[tuple[int, Any]]] + ..., None | Awaitable[None] | tuple[int, Any] | Awaitable[tuple[int, Any]] ] @@ -349,9 +347,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): f = failure.Failure() self._send_error_response(f, request) - async def _async_render( - self, request: "SynapseRequest" - ) -> Optional[tuple[int, Any]]: + async def _async_render(self, request: "SynapseRequest") -> tuple[int, Any] | None: """Delegates to `_async_render_` methods, or returns a 400 if no appropriate method exists. Can be overridden in sub classes for different routing. @@ -406,7 +402,7 @@ class DirectServeJsonResource(_AsyncResource): canonical_json: bool = False, extract_context: bool = False, # Clock is optional as this class is exposed to the module API. - clock: Optional[Clock] = None, + clock: Clock | None = None, ): """ Args: @@ -603,7 +599,7 @@ class DirectServeHtmlResource(_AsyncResource): self, extract_context: bool = False, # Clock is optional as this class is exposed to the module API. - clock: Optional[Clock] = None, + clock: Clock | None = None, ): """ Args: @@ -732,7 +728,7 @@ class _ByteProducer: request: Request, iterator: Iterator[bytes], ): - self._request: Optional[Request] = request + self._request: Request | None = request self._iterator = iterator self._paused = False self.tracing_scope = start_active_span( @@ -831,7 +827,7 @@ def respond_with_json( json_object: Any, send_cors: bool = False, canonical_json: bool = True, -) -> Optional[int]: +) -> int | None: """Sends encoded JSON in response to the given request. Args: @@ -880,7 +876,7 @@ def respond_with_json_bytes( code: int, json_bytes: bytes, send_cors: bool = False, -) -> Optional[int]: +) -> int | None: """Sends encoded JSON in response to the given request. Args: @@ -929,7 +925,7 @@ async def _async_write_json_to_request_in_thread( expensive. """ - def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes: + def encode(opentracing_span: "opentracing.Span | None") -> bytes: # it might take a while for the threadpool to schedule us, so we write # opentracing logs once we actually get scheduled, so that we can see how # much that contributed. diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index bca93fb036..c182497f2d 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -29,7 +29,6 @@ from typing import ( TYPE_CHECKING, Literal, Mapping, - Optional, Sequence, TypeVar, overload, @@ -80,26 +79,26 @@ def parse_integer( @overload -def parse_integer(request: Request, name: str, *, negative: bool) -> Optional[int]: ... +def parse_integer(request: Request, name: str, *, negative: bool) -> int | None: ... @overload def parse_integer( request: Request, name: str, - default: Optional[int] = None, + default: int | None = None, required: bool = False, negative: bool = False, -) -> Optional[int]: ... +) -> int | None: ... def parse_integer( request: Request, name: str, - default: Optional[int] = None, + default: int | None = None, required: bool = False, negative: bool = False, -) -> Optional[int]: +) -> int | None: """Parse an integer parameter from the request string Args: @@ -136,8 +135,8 @@ def parse_integer_from_args( def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[int] = None, -) -> Optional[int]: ... + default: int | None = None, +) -> int | None: ... @overload @@ -153,19 +152,19 @@ def parse_integer_from_args( def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[int] = None, + default: int | None = None, required: bool = False, negative: bool = False, -) -> Optional[int]: ... +) -> int | None: ... def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[int] = None, + default: int | None = None, required: bool = False, negative: bool = False, -) -> Optional[int]: +) -> int | None: """Parse an integer parameter from the request string Args: @@ -217,13 +216,13 @@ def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bo @overload def parse_boolean( - request: Request, name: str, default: Optional[bool] = None, required: bool = False -) -> Optional[bool]: ... + request: Request, name: str, default: bool | None = None, required: bool = False +) -> bool | None: ... def parse_boolean( - request: Request, name: str, default: Optional[bool] = None, required: bool = False -) -> Optional[bool]: + request: Request, name: str, default: bool | None = None, required: bool = False +) -> bool | None: """Parse a boolean parameter from the request query string Args: @@ -265,17 +264,17 @@ def parse_boolean_from_args( def parse_boolean_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[bool] = None, + default: bool | None = None, required: bool = False, -) -> Optional[bool]: ... +) -> bool | None: ... def parse_boolean_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[bool] = None, + default: bool | None = None, required: bool = False, -) -> Optional[bool]: +) -> bool | None: """Parse a boolean parameter from the request query string Args: @@ -318,8 +317,8 @@ def parse_boolean_from_args( def parse_bytes_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[bytes] = None, -) -> Optional[bytes]: ... + default: bytes | None = None, +) -> bytes | None: ... @overload @@ -336,17 +335,17 @@ def parse_bytes_from_args( def parse_bytes_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[bytes] = None, + default: bytes | None = None, required: bool = False, -) -> Optional[bytes]: ... +) -> bytes | None: ... def parse_bytes_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[bytes] = None, + default: bytes | None = None, required: bool = False, -) -> Optional[bytes]: +) -> bytes | None: """ Parse a string parameter as bytes from the request query string. @@ -380,7 +379,7 @@ def parse_string( name: str, default: str, *, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", ) -> str: ... @@ -391,7 +390,7 @@ def parse_string( name: str, *, required: Literal[True], - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", ) -> str: ... @@ -401,21 +400,21 @@ def parse_string( request: Request, name: str, *, - default: Optional[str] = None, + default: str | None = None, required: bool = False, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", -) -> Optional[str]: ... +) -> str | None: ... def parse_string( request: Request, name: str, - default: Optional[str] = None, + default: str | None = None, required: bool = False, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", -) -> Optional[str]: +) -> str | None: """ Parse a string parameter from the request query string. @@ -455,10 +454,10 @@ def parse_string( def parse_json( request: Request, name: str, - default: Optional[dict] = None, + default: dict | None = None, required: bool = False, encoding: str = "ascii", -) -> Optional[JsonDict]: +) -> JsonDict | None: """ Parse a JSON parameter from the request query string. @@ -492,10 +491,10 @@ def parse_json( def parse_json_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[dict] = None, + default: dict | None = None, required: bool = False, encoding: str = "ascii", -) -> Optional[JsonDict]: +) -> JsonDict | None: """ Parse a JSON parameter from the request query string. @@ -559,9 +558,9 @@ def parse_enum( request: Request, name: str, E: type[EnumT], - default: Optional[EnumT] = None, + default: EnumT | None = None, required: bool = False, -) -> Optional[EnumT]: +) -> EnumT | None: """ Parse an enum parameter from the request query string. @@ -601,7 +600,7 @@ def parse_enum( def _parse_string_value( value: bytes, - allowed_values: Optional[StrCollection], + allowed_values: StrCollection | None, name: str, encoding: str, ) -> str: @@ -627,9 +626,9 @@ def parse_strings_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, *, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", -) -> Optional[list[str]]: ... +) -> list[str] | None: ... @overload @@ -638,7 +637,7 @@ def parse_strings_from_args( name: str, default: list[str], *, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", ) -> list[str]: ... @@ -649,7 +648,7 @@ def parse_strings_from_args( name: str, *, required: Literal[True], - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", ) -> list[str]: ... @@ -658,22 +657,22 @@ def parse_strings_from_args( def parse_strings_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[list[str]] = None, + default: list[str] | None = None, *, required: bool = False, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", -) -> Optional[list[str]]: ... +) -> list[str] | None: ... def parse_strings_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[list[str]] = None, + default: list[str] | None = None, required: bool = False, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", -) -> Optional[list[str]]: +) -> list[str] | None: """ Parse a string parameter from the request query string list. @@ -720,21 +719,21 @@ def parse_strings_from_args( def parse_string_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[str] = None, + default: str | None = None, *, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", -) -> Optional[str]: ... +) -> str | None: ... @overload def parse_string_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[str] = None, + default: str | None = None, *, required: Literal[True], - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", ) -> str: ... @@ -743,21 +742,21 @@ def parse_string_from_args( def parse_string_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[str] = None, + default: str | None = None, required: bool = False, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", -) -> Optional[str]: ... +) -> str | None: ... def parse_string_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[str] = None, + default: str | None = None, required: bool = False, - allowed_values: Optional[StrCollection] = None, + allowed_values: StrCollection | None = None, encoding: str = "ascii", -) -> Optional[str]: +) -> str | None: """ Parse the string parameter from the request query string list and return the first result. @@ -812,12 +811,12 @@ def parse_json_value_from_request( @overload def parse_json_value_from_request( request: Request, allow_empty_body: bool = False -) -> Optional[JsonDict]: ... +) -> JsonDict | None: ... def parse_json_value_from_request( request: Request, allow_empty_body: bool = False -) -> Optional[JsonDict]: +) -> JsonDict | None: """Parse a JSON value from the body of a twisted HTTP request. Args: @@ -980,8 +979,8 @@ class ResolveRoomIdMixin: self.room_member_handler = hs.get_room_member_handler() async def resolve_room_id( - self, room_identifier: str, remote_room_hosts: Optional[list[str]] = None - ) -> tuple[str, Optional[list[str]]]: + self, room_identifier: str, remote_room_hosts: list[str] | None = None + ) -> tuple[str, list[str] | None]: """ Resolve a room identifier to a room ID, if necessary. diff --git a/synapse/http/site.py b/synapse/http/site.py index ccf6ff27f0..03d5d048b1 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -22,7 +22,7 @@ import contextlib import logging import time from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Generator, Optional, Union +from typing import TYPE_CHECKING, Any, Generator import attr from zope.interface import implementer @@ -88,7 +88,7 @@ class SynapseRequest(Request): our_server_name: str, *args: Any, max_request_body_size: int = 1024, - request_id_header: Optional[str] = None, + request_id_header: str | None = None, **kw: Any, ): super().__init__(channel, *args, **kw) @@ -102,18 +102,18 @@ class SynapseRequest(Request): # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. - self._requester: Optional[Union[Requester, str]] = None + self._requester: Requester | str | None = None # An opentracing span for this request. Will be closed when the request is # completely processed. - self._opentracing_span: "Optional[opentracing.Span]" = None + self._opentracing_span: "opentracing.Span | None" = None # we can't yet create the logcontext, as we don't know the method. - self.logcontext: Optional[LoggingContext] = None + self.logcontext: LoggingContext | None = None # The `Deferred` to cancel if the client disconnects early and # `is_render_cancellable` is set. Expected to be set by `Resource.render`. - self.render_deferred: Optional["Deferred[None]"] = None + self.render_deferred: "Deferred[None]" | None = None # A boolean indicating whether `render_deferred` should be cancelled if the # client disconnects early. Expected to be set by the coroutine started by # `Resource.render`, if rendering is asynchronous. @@ -127,11 +127,11 @@ class SynapseRequest(Request): self._is_processing = False # the time when the asynchronous request handler completed its processing - self._processing_finished_time: Optional[float] = None + self._processing_finished_time: float | None = None # what time we finished sending the response to the client (or the connection # dropped) - self.finish_time: Optional[float] = None + self.finish_time: float | None = None def __repr__(self) -> str: # We overwrite this so that we don't log ``access_token`` @@ -195,11 +195,11 @@ class SynapseRequest(Request): super().handleContentChunk(data) @property - def requester(self) -> Optional[Union[Requester, str]]: + def requester(self) -> Requester | str | None: return self._requester @requester.setter - def requester(self, value: Union[Requester, str]) -> None: + def requester(self, value: Requester | str) -> None: # Store the requester, and update some properties based on it. # This should only be called once. @@ -246,7 +246,7 @@ class SynapseRequest(Request): Returns: The redacted URI as a string. """ - uri: Union[bytes, str] = self.uri + uri: bytes | str = self.uri if isinstance(uri, bytes): uri = uri.decode("ascii", errors="replace") return redact_uri(uri) @@ -261,12 +261,12 @@ class SynapseRequest(Request): Returns: The request method as a string. """ - method: Union[bytes, str] = self.method + method: bytes | str = self.method if isinstance(method, bytes): return self.method.decode("ascii") return method - def get_authenticated_entity(self) -> tuple[Optional[str], Optional[str]]: + def get_authenticated_entity(self) -> tuple[str | None, str | None]: """ Get the "authenticated" entity of the request, which might be the user performing the action, or a user being puppeted by a server admin. @@ -403,7 +403,7 @@ class SynapseRequest(Request): with PreserveLoggingContext(self.logcontext): self._finished_processing() - def connectionLost(self, reason: Union[Failure, Exception]) -> None: + def connectionLost(self, reason: Failure | Exception) -> None: """Called when the client connection is closed before the response is written. Overrides twisted.web.server.Request.connectionLost to record the finish time and @@ -595,7 +595,7 @@ class XForwardedForRequest(SynapseRequest): """ # the client IP and ssl flag, as extracted from the headers. - _forwarded_for: "Optional[_XForwardedForAddress]" = None + _forwarded_for: "_XForwardedForAddress | None" = None _forwarded_https: bool = False def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None: @@ -674,7 +674,7 @@ class SynapseProtocol(HTTPChannel): site: "SynapseSite", our_server_name: str, max_request_body_size: int, - request_id_header: Optional[str], + request_id_header: str | None, request_class: type, ): super().__init__() @@ -821,5 +821,5 @@ class SynapseSite(ProxySite): @attr.s(auto_attribs=True, frozen=True, slots=True) class RequestInfo: - user_agent: Optional[str] + user_agent: str | None ip: str diff --git a/synapse/http/types.py b/synapse/http/types.py index dd954b6c20..a04a285397 100644 --- a/synapse/http/types.py +++ b/synapse/http/types.py @@ -18,10 +18,10 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Iterable, Mapping, Union +from typing import Iterable, Mapping # the type of the query params, to be passed into `urlencode` with `doseq=True`. -QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]] -QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]] +QueryParamValue = str | bytes | Iterable[str | bytes] +QueryParams = Mapping[str, QueryParamValue] | Mapping[bytes, QueryParamValue] __all__ = ["QueryParams"] diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index a3444221a0..e3e0ba4beb 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -25,7 +25,7 @@ import traceback from collections import deque from ipaddress import IPv4Address, IPv6Address, ip_address from math import floor -from typing import Callable, Optional +from typing import Callable import attr from zope.interface import implementer @@ -113,7 +113,7 @@ class RemoteHandler(logging.Handler): port: int, maximum_buffer: int = 1000, level: int = logging.NOTSET, - _reactor: Optional[IReactorTime] = None, + _reactor: IReactorTime | None = None, ): super().__init__(level=level) self.host = host @@ -121,8 +121,8 @@ class RemoteHandler(logging.Handler): self.maximum_buffer = maximum_buffer self._buffer: deque[logging.LogRecord] = deque() - self._connection_waiter: Optional[Deferred] = None - self._producer: Optional[LogProducer] = None + self._connection_waiter: Deferred | None = None + self._producer: LogProducer | None = None # Connect without DNS lookups if it's a direct IP. if _reactor is None: diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 919493d1a3..2410d95720 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -40,7 +40,6 @@ from typing import ( Awaitable, Callable, Literal, - Optional, TypeVar, Union, overload, @@ -88,7 +87,7 @@ try: is_thread_resource_usage_supported = True - def get_thread_resource_usage() -> "Optional[resource.struct_rusage]": + def get_thread_resource_usage() -> "resource.struct_rusage | None": return resource.getrusage(RUSAGE_THREAD) except Exception: @@ -96,7 +95,7 @@ except Exception: # won't track resource usage. is_thread_resource_usage_supported = False - def get_thread_resource_usage() -> "Optional[resource.struct_rusage]": + def get_thread_resource_usage() -> "resource.struct_rusage | None": return None @@ -137,7 +136,7 @@ class ContextResourceUsage: "evt_db_fetch_count", ] - def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None: + def __init__(self, copy_from: "ContextResourceUsage | None" = None) -> None: """Create a new ContextResourceUsage Args: @@ -230,8 +229,8 @@ class ContextRequest: request_id: str ip_address: str site_tag: str - requester: Optional[str] - authenticated_entity: Optional[str] + requester: str | None + authenticated_entity: str | None method: str url: str protocol: str @@ -274,10 +273,10 @@ class _Sentinel: def __str__(self) -> str: return "sentinel" - def start(self, rusage: "Optional[resource.struct_rusage]") -> None: + def start(self, rusage: "resource.struct_rusage | None") -> None: pass - def stop(self, rusage: "Optional[resource.struct_rusage]") -> None: + def stop(self, rusage: "resource.struct_rusage | None") -> None: pass def add_database_transaction(self, duration_sec: float) -> None: @@ -334,8 +333,8 @@ class LoggingContext: *, name: str, server_name: str, - parent_context: "Optional[LoggingContext]" = None, - request: Optional[ContextRequest] = None, + parent_context: "LoggingContext | None" = None, + request: ContextRequest | None = None, ) -> None: self.previous_context = current_context() @@ -344,14 +343,14 @@ class LoggingContext: # The thread resource usage when the logcontext became active. None # if the context is not currently active. - self.usage_start: Optional[resource.struct_rusage] = None + self.usage_start: resource.struct_rusage | None = None self.name = name self.server_name = server_name self.main_thread = get_thread_id() self.request = None self.tag = "" - self.scope: Optional["_LogContextScope"] = None + self.scope: "_LogContextScope" | None = None # keep track of whether we have hit the __exit__ block for this context # (suggesting that the the thing that created the context thinks it should @@ -391,9 +390,9 @@ class LoggingContext: def __exit__( self, - type: Optional[type[BaseException]], - value: Optional[BaseException], - traceback: Optional[TracebackType], + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, ) -> None: """Restore the logging context in thread local storage to the state it was before this context was entered. @@ -417,7 +416,7 @@ class LoggingContext: # recorded against the correct metrics. self.finished = True - def start(self, rusage: "Optional[resource.struct_rusage]") -> None: + def start(self, rusage: "resource.struct_rusage | None") -> None: """ Record that this logcontext is currently running. @@ -442,7 +441,7 @@ class LoggingContext: else: self.usage_start = rusage - def stop(self, rusage: "Optional[resource.struct_rusage]") -> None: + def stop(self, rusage: "resource.struct_rusage | None") -> None: """ Record that this logcontext is no longer running. @@ -702,9 +701,9 @@ class PreserveLoggingContext: def __exit__( self, - type: Optional[type[BaseException]], - value: Optional[BaseException], - traceback: Optional[TracebackType], + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, ) -> None: logcontext_debug_logger.debug( "PreserveLoggingContext(%s).__exit %s --> %s", @@ -823,10 +822,7 @@ def preserve_fn(f: Callable[P, R]) -> Callable[P, "defer.Deferred[R]"]: ... def preserve_fn( - f: Union[ - Callable[P, R], - Callable[P, Awaitable[R]], - ], + f: Callable[P, R] | Callable[P, Awaitable[R]], ) -> Callable[P, "defer.Deferred[R]"]: """Function decorator which wraps the function with run_in_background""" @@ -852,10 +848,7 @@ def run_in_background( def run_in_background( - f: Union[ - Callable[P, R], - Callable[P, Awaitable[R]], - ], + f: Callable[P, R] | Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs, ) -> "defer.Deferred[R]": diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py index e5d73a47a8..70b6d7f6a1 100644 --- a/synapse/logging/formatter.py +++ b/synapse/logging/formatter.py @@ -23,7 +23,6 @@ import logging import traceback from io import StringIO from types import TracebackType -from typing import Optional class LogFormatter(logging.Formatter): @@ -39,9 +38,9 @@ class LogFormatter(logging.Formatter): def formatException( self, ei: tuple[ - Optional[type[BaseException]], - Optional[BaseException], - Optional[TracebackType], + type[BaseException] | None, + BaseException | None, + TracebackType | None, ], ) -> str: sio = StringIO() diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py index b7945aac72..976c7075d4 100644 --- a/synapse/logging/handlers.py +++ b/synapse/logging/handlers.py @@ -3,7 +3,7 @@ import time from logging import Handler, LogRecord from logging.handlers import MemoryHandler from threading import Thread -from typing import Optional, cast +from typing import cast from twisted.internet.interfaces import IReactorCore @@ -23,10 +23,10 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): self, capacity: int, flushLevel: int = logging.ERROR, - target: Optional[Handler] = None, + target: Handler | None = None, flushOnClose: bool = True, period: float = 5.0, - reactor: Optional[IReactorCore] = None, + reactor: IReactorCore | None = None, ) -> None: """ period: the period between automatic flushes diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index fbb9971b32..6e4e029163 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -289,7 +289,7 @@ try: except Exception: logger.exception("Failed to report span") - RustReporter: Optional[type[_WrappedRustReporter]] = _WrappedRustReporter + RustReporter: type[_WrappedRustReporter] | None = _WrappedRustReporter except ImportError: RustReporter = None @@ -354,7 +354,7 @@ class SynapseBaggage: # Block everything by default # A regex which matches the server_names to expose traces for. # None means 'block everything'. -_homeserver_whitelist: Optional[Pattern[str]] = None +_homeserver_whitelist: Pattern[str] | None = None # Util methods @@ -370,11 +370,11 @@ R = TypeVar("R") T = TypeVar("T") -def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]: +def only_if_tracing(func: Callable[P, R]) -> Callable[P, R | None]: """Executes the function only if we're tracing. Otherwise returns None.""" @wraps(func) - def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: + def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> R | None: if opentracing: return func(*args, **kwargs) else: @@ -386,18 +386,18 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]: @overload def ensure_active_span( message: str, -) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]: ... +) -> Callable[[Callable[P, R]], Callable[P, R | None]]: ... @overload def ensure_active_span( message: str, ret: T -) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]: ... +) -> Callable[[Callable[P, R]], Callable[P, T | R]]: ... def ensure_active_span( - message: str, ret: Optional[T] = None -) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]: + message: str, ret: T | None = None +) -> Callable[[Callable[P, R]], Callable[P, T | None | R]]: """Executes the operation only if opentracing is enabled and there is an active span. If there is no active span it logs message at the error level. @@ -413,11 +413,11 @@ def ensure_active_span( def ensure_active_span_inner_1( func: Callable[P, R], - ) -> Callable[P, Union[Optional[T], R]]: + ) -> Callable[P, T | None | R]: @wraps(func) def ensure_active_span_inner_2( *args: P.args, **kwargs: P.kwargs - ) -> Union[Optional[T], R]: + ) -> T | None | R: if not opentracing: return ret @@ -532,10 +532,10 @@ def whitelisted_homeserver(destination: str) -> bool: # Could use kwargs but I want these to be explicit def start_active_span( operation_name: str, - child_of: Optional[Union["opentracing.Span", "opentracing.SpanContext"]] = None, - references: Optional[list["opentracing.Reference"]] = None, - tags: Optional[dict[str, str]] = None, - start_time: Optional[float] = None, + child_of: Union["opentracing.Span", "opentracing.SpanContext"] | None = None, + references: list["opentracing.Reference"] | None = None, + tags: dict[str, str] | None = None, + start_time: float | None = None, ignore_active_span: bool = False, finish_on_close: bool = True, *, @@ -573,9 +573,9 @@ def start_active_span( def start_active_span_follows_from( operation_name: str, contexts: Collection, - child_of: Optional[Union["opentracing.Span", "opentracing.SpanContext"]] = None, - tags: Optional[dict[str, str]] = None, - start_time: Optional[float] = None, + child_of: Union["opentracing.Span", "opentracing.SpanContext"] | None = None, + tags: dict[str, str] | None = None, + start_time: float | None = None, ignore_active_span: bool = False, *, inherit_force_tracing: bool = False, @@ -630,9 +630,9 @@ def start_active_span_follows_from( def start_active_span_from_edu( edu_content: dict[str, Any], operation_name: str, - references: Optional[list["opentracing.Reference"]] = None, - tags: Optional[dict[str, str]] = None, - start_time: Optional[float] = None, + references: list["opentracing.Reference"] | None = None, + tags: dict[str, str] | None = None, + start_time: float | None = None, ignore_active_span: bool = False, finish_on_close: bool = True, ) -> "opentracing.Scope": @@ -699,14 +699,14 @@ def active_span( @ensure_active_span("set a tag") -def set_tag(key: str, value: Union[str, bool, int, float]) -> None: +def set_tag(key: str, value: str | bool | int | float) -> None: """Sets a tag on the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_tag(key, value) @ensure_active_span("log") -def log_kv(key_values: dict[str, Any], timestamp: Optional[float] = None) -> None: +def log_kv(key_values: dict[str, Any], timestamp: float | None = None) -> None: """Log to the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.log_kv(key_values, timestamp) @@ -758,7 +758,7 @@ def is_context_forced_tracing( @ensure_active_span("inject the span into a header dict") def inject_header_dict( headers: dict[bytes, list[bytes]], - destination: Optional[str] = None, + destination: str | None = None, check_destination: bool = True, ) -> None: """ @@ -826,7 +826,7 @@ def inject_request_headers(headers: dict[str, str]) -> None: @ensure_active_span( "get the active span context as a dict", ret=cast(dict[str, str], {}) ) -def get_active_span_text_map(destination: Optional[str] = None) -> dict[str, str]: +def get_active_span_text_map(destination: str | None = None) -> dict[str, str]: """ Gets a span context as a dict. This can be used instead of manually injecting a span into an empty carrier. @@ -865,7 +865,7 @@ def active_span_context_as_string() -> str: return json_encoder.encode(carrier) -def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]": +def span_context_from_request(request: Request) -> "opentracing.SpanContext | None": """Extract an opentracing context from the headers on an HTTP request This is useful when we have received an HTTP request from another part of our @@ -1119,7 +1119,7 @@ def trace_servlet( # with JsonResource). scope.span.set_operation_name(request.request_metrics.name) - # Mypy seems to think that start_context.tag below can be Optional[str], but + # Mypy seems to think that start_context.tag below can be str | None, but # that doesn't appear to be correct and works in practice. request_tags[SynapseTags.REQUEST_TAG] = ( diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index feaadc4d87..f3ec07eecf 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -20,7 +20,6 @@ # import logging -from typing import Optional from opentracing import Scope, ScopeManager, Span @@ -47,7 +46,7 @@ class LogContextScopeManager(ScopeManager): pass @property - def active(self) -> Optional[Scope]: + def active(self) -> Scope | None: """ Returns the currently active Scope which can be used to access the currently active Scope.span. diff --git a/synapse/media/_base.py b/synapse/media/_base.py index 319ca662e2..e0313d2893 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -30,7 +30,6 @@ from typing import ( Awaitable, BinaryIO, Generator, - Optional, ) import attr @@ -133,8 +132,8 @@ async def respond_with_file( request: SynapseRequest, media_type: str, file_path: str, - file_size: Optional[int] = None, - upload_name: Optional[str] = None, + file_size: int | None = None, + upload_name: str | None = None, ) -> None: logger.debug("Responding with %r", file_path) @@ -156,8 +155,8 @@ async def respond_with_file( def add_file_headers( request: Request, media_type: str, - file_size: Optional[int], - upload_name: Optional[str], + file_size: int | None, + upload_name: str | None, ) -> None: """Adds the correct response headers in preparation for responding with the media. @@ -301,10 +300,10 @@ def _can_encode_filename_as_token(x: str) -> bool: async def respond_with_multipart_responder( clock: Clock, request: SynapseRequest, - responder: "Optional[Responder]", + responder: "Responder | None", media_type: str, - media_length: Optional[int], - upload_name: Optional[str], + media_length: int | None, + upload_name: str | None, ) -> None: """ Responds to requests originating from the federation media `/download` endpoint by @@ -392,10 +391,10 @@ async def respond_with_multipart_responder( async def respond_with_responder( request: SynapseRequest, - responder: "Optional[Responder]", + responder: "Responder | None", media_type: str, - file_size: Optional[int], - upload_name: Optional[str] = None, + file_size: int | None, + upload_name: str | None = None, ) -> None: """Responds to the request with given responder. If responder is None then returns 404. @@ -501,9 +500,9 @@ class Responder(ABC): def __exit__( # noqa: B027 self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: pass @@ -526,47 +525,47 @@ class FileInfo: """Details about a requested/uploaded file.""" # The server name where the media originated from, or None if local. - server_name: Optional[str] + server_name: str | None # The local ID of the file. For local files this is the same as the media_id file_id: str # If the file is for the url preview cache url_cache: bool = False # Whether the file is a thumbnail or not. - thumbnail: Optional[ThumbnailInfo] = None + thumbnail: ThumbnailInfo | None = None # The below properties exist to maintain compatibility with third-party modules. @property - def thumbnail_width(self) -> Optional[int]: + def thumbnail_width(self) -> int | None: if not self.thumbnail: return None return self.thumbnail.width @property - def thumbnail_height(self) -> Optional[int]: + def thumbnail_height(self) -> int | None: if not self.thumbnail: return None return self.thumbnail.height @property - def thumbnail_method(self) -> Optional[str]: + def thumbnail_method(self) -> str | None: if not self.thumbnail: return None return self.thumbnail.method @property - def thumbnail_type(self) -> Optional[str]: + def thumbnail_type(self) -> str | None: if not self.thumbnail: return None return self.thumbnail.type @property - def thumbnail_length(self) -> Optional[int]: + def thumbnail_length(self) -> int | None: if not self.thumbnail: return None return self.thumbnail.length -def get_filename_from_headers(headers: dict[bytes, list[bytes]]) -> Optional[str]: +def get_filename_from_headers(headers: dict[bytes, list[bytes]]) -> str | None: """ Get the filename of the downloaded file by inspecting the Content-Disposition HTTP header. @@ -703,9 +702,9 @@ class ThreadedFileSender: self.clock = hs.get_clock() self.thread_pool = hs.get_media_sender_thread_pool() - self.file: Optional[BinaryIO] = None + self.file: BinaryIO | None = None self.deferred: "Deferred[None]" = Deferred() - self.consumer: Optional[interfaces.IConsumer] = None + self.consumer: interfaces.IConsumer | None = None # Signals if the thread should keep reading/sending data. Set means # continue, clear means pause. diff --git a/synapse/media/filepath.py b/synapse/media/filepath.py index 7659971661..df637f3be3 100644 --- a/synapse/media/filepath.py +++ b/synapse/media/filepath.py @@ -24,7 +24,7 @@ import functools import os import re import string -from typing import Any, Callable, TypeVar, Union, cast +from typing import Any, Callable, TypeVar, cast NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") @@ -46,7 +46,7 @@ def _wrap_in_base_path(func: F) -> F: GetPathMethod = TypeVar( - "GetPathMethod", bound=Union[Callable[..., str], Callable[..., list[str]]] + "GetPathMethod", bound=Callable[..., str] | Callable[..., list[str]] ) @@ -73,7 +73,7 @@ def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMe @functools.wraps(func) def _wrapped( self: "MediaFilePaths", *args: Any, **kwargs: Any - ) -> Union[str, list[str]]: + ) -> str | list[str]: path_or_paths = func(self, *args, **kwargs) if isinstance(path_or_paths, list): diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index eda1410767..7b4408b2bc 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -24,7 +24,7 @@ import logging import os import shutil from io import BytesIO -from typing import IO, TYPE_CHECKING, Optional +from typing import IO, TYPE_CHECKING import attr from matrix_common.types.mxc_uri import MXCUri @@ -170,7 +170,7 @@ class MediaRepository: ) if hs.config.media.url_preview_enabled: - self.url_previewer: Optional[UrlPreviewer] = UrlPreviewer( + self.url_previewer: UrlPreviewer | None = UrlPreviewer( hs, self, self.media_storage ) else: @@ -208,7 +208,7 @@ class MediaRepository: local_media, remote_media, self.clock.time_msec() ) - def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: + def mark_recently_accessed(self, server_name: str | None, media_id: str) -> None: """Mark the given media as recently accessed. Args: @@ -298,11 +298,11 @@ class MediaRepository: async def create_or_update_content( self, media_type: str, - upload_name: Optional[str], + upload_name: str | None, content: IO, content_length: int, auth_user: UserID, - media_id: Optional[str] = None, + media_id: str | None = None, ) -> MXCUri: """Create or update the content of the given media ID. @@ -354,7 +354,7 @@ class MediaRepository: # This is the total size of media uploaded by the user in the last # `time_period_ms` milliseconds, or None if we haven't checked yet. - uploaded_media_size: Optional[int] = None + uploaded_media_size: int | None = None for limit in media_upload_limits: # We only need to check the amount of media uploaded by the user in @@ -422,7 +422,7 @@ class MediaRepository: async def get_cached_remote_media_info( self, origin: str, media_id: str - ) -> Optional[RemoteMedia]: + ) -> RemoteMedia | None: """ Get cached remote media info for a given origin/media ID combo. If the requested media is not found locally, it will not be requested over federation and the @@ -439,7 +439,7 @@ class MediaRepository: async def get_local_media_info( self, request: SynapseRequest, media_id: str, max_timeout_ms: int - ) -> Optional[LocalMedia]: + ) -> LocalMedia | None: """Gets the info dictionary for given local media ID. If the media has not been uploaded yet, this function will wait up to ``max_timeout_ms`` milliseconds for the media to be uploaded. @@ -495,7 +495,7 @@ class MediaRepository: self, request: SynapseRequest, media_id: str, - name: Optional[str], + name: str | None, max_timeout_ms: int, allow_authenticated: bool = True, federation: bool = False, @@ -555,7 +555,7 @@ class MediaRepository: request: SynapseRequest, server_name: str, media_id: str, - name: Optional[str], + name: str | None, max_timeout_ms: int, ip_address: str, use_federation_endpoint: bool, @@ -696,7 +696,7 @@ class MediaRepository: ip_address: str, use_federation_endpoint: bool, allow_authenticated: bool, - ) -> tuple[Optional[Responder], RemoteMedia]: + ) -> tuple[Responder | None, RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -1065,7 +1065,7 @@ class MediaRepository: t_height: int, t_method: str, t_type: str, - ) -> Optional[BytesIO]: + ) -> BytesIO | None: m_width = thumbnailer.width m_height = thumbnailer.height @@ -1099,7 +1099,7 @@ class MediaRepository: t_method: str, t_type: str, url_cache: bool, - ) -> Optional[tuple[str, FileInfo]]: + ) -> tuple[str, FileInfo] | None: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) @@ -1175,7 +1175,7 @@ class MediaRepository: t_height: int, t_method: str, t_type: str, - ) -> Optional[str]: + ) -> str | None: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id) ) @@ -1247,12 +1247,12 @@ class MediaRepository: @trace async def _generate_thumbnails( self, - server_name: Optional[str], + server_name: str | None, media_id: str, file_id: str, media_type: str, url_cache: bool = False, - ) -> Optional[dict]: + ) -> dict | None: """Generate and store thumbnails for an image. Args: diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index f6be9edf50..bc12212c46 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -34,9 +34,7 @@ from typing import ( AsyncIterator, BinaryIO, Callable, - Optional, Sequence, - Union, cast, ) from uuid import uuid4 @@ -79,7 +77,7 @@ class SHA256TransparentIOWriter: self._hash = hashlib.sha256() self._source = source - def write(self, buffer: Union[bytes, bytearray]) -> int: + def write(self, buffer: bytes | bytearray) -> int: """Wrapper for source.write() Args: @@ -260,7 +258,7 @@ class MediaStorage: raise e from None - async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: + async def fetch_media(self, file_info: FileInfo) -> Responder | None: """Attempts to fetch media described by file_info from the local cache and configured storage providers. @@ -420,9 +418,9 @@ class FileResponder(Responder): def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: self.open_file.close() @@ -476,7 +474,7 @@ class MultipartFileConsumer: file_content_type: str, json_object: JsonDict, disposition: str, - content_length: Optional[int], + content_length: int | None, ) -> None: self.clock = clock self.wrapped_consumer = wrapped_consumer @@ -488,8 +486,8 @@ class MultipartFileConsumer: # The producer that registered with us, and if it's a push or pull # producer. - self.producer: Optional["interfaces.IProducer"] = None - self.streaming: Optional[bool] = None + self.producer: "interfaces.IProducer" | None = None + self.streaming: bool | None = None # Whether the wrapped consumer has asked us to pause. self.paused = False @@ -618,7 +616,7 @@ class MultipartFileConsumer: # repeatedly calling `resumeProducing` in a loop. run_in_background(self._resumeProducingRepeatedly) - def content_length(self) -> Optional[int]: + def content_length(self) -> int | None: """ Calculate the content length of the multipart response in bytes. @@ -671,7 +669,7 @@ class Header: self, name: bytes, value: Any, - params: Optional[list[tuple[Any, Any]]] = None, + params: list[tuple[Any, Any]] | None = None, ): self.name = name self.value = value @@ -693,7 +691,7 @@ class Header: return h.read() -def escape(value: Union[str, bytes]) -> str: +def escape(value: str | bytes) -> str: """ This function prevents header values from corrupting the request, a newline in the file name parameter makes form-data request unreadable diff --git a/synapse/media/oembed.py b/synapse/media/oembed.py index 059d8ad1cf..7e44072130 100644 --- a/synapse/media/oembed.py +++ b/synapse/media/oembed.py @@ -21,7 +21,7 @@ import html import logging import urllib.parse -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast import attr @@ -42,12 +42,12 @@ class OEmbedResult: # The Open Graph result (converted from the oEmbed result). open_graph_result: JsonDict # The author_name of the oEmbed result - author_name: Optional[str] + author_name: str | None # Number of milliseconds to cache the content, according to the oEmbed response. # # This will be None if no cache-age is provided in the oEmbed response (or # if the oEmbed response cannot be turned into an Open Graph response). - cache_age: Optional[int] + cache_age: int | None class OEmbedProvider: @@ -80,7 +80,7 @@ class OEmbedProvider: for pattern in oembed_endpoint.url_patterns: self._oembed_patterns[pattern] = api_endpoint - def get_oembed_url(self, url: str) -> Optional[str]: + def get_oembed_url(self, url: str) -> str | None: """ Check whether the URL should be downloaded as oEmbed content instead. @@ -105,7 +105,7 @@ class OEmbedProvider: # No match. return None - def autodiscover_from_html(self, tree: "etree._Element") -> Optional[str]: + def autodiscover_from_html(self, tree: "etree._Element") -> str | None: """ Search an HTML document for oEmbed autodiscovery information. diff --git a/synapse/media/preview_html.py b/synapse/media/preview_html.py index 6a8e479152..22ad581f82 100644 --- a/synapse/media/preview_html.py +++ b/synapse/media/preview_html.py @@ -27,7 +27,6 @@ from typing import ( Generator, Iterable, Optional, - Union, cast, ) @@ -48,7 +47,7 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"} -def _normalise_encoding(encoding: str) -> Optional[str]: +def _normalise_encoding(encoding: str) -> str | None: """Use the Python codec's name as the normalised entry.""" try: return codecs.lookup(encoding).name @@ -56,9 +55,7 @@ def _normalise_encoding(encoding: str) -> Optional[str]: return None -def _get_html_media_encodings( - body: bytes, content_type: Optional[str] -) -> Iterable[str]: +def _get_html_media_encodings(body: bytes, content_type: str | None) -> Iterable[str]: """ Get potential encoding of the body based on the (presumably) HTML body or the content-type header. @@ -119,7 +116,7 @@ def _get_html_media_encodings( def decode_body( - body: bytes, uri: str, content_type: Optional[str] = None + body: bytes, uri: str, content_type: str | None = None ) -> Optional["etree._Element"]: """ This uses lxml to parse the HTML document. @@ -186,8 +183,8 @@ def _get_meta_tags( tree: "etree._Element", property: str, prefix: str, - property_mapper: Optional[Callable[[str], Optional[str]]] = None, -) -> dict[str, Optional[str]]: + property_mapper: Callable[[str], str | None] | None = None, +) -> dict[str, str | None]: """ Search for meta tags prefixed with a particular string. @@ -202,9 +199,9 @@ def _get_meta_tags( Returns: A map of tag name to value. """ - # This actually returns Dict[str, str], but the caller sets this as a variable - # which is Dict[str, Optional[str]]. - results: dict[str, Optional[str]] = {} + # This actually returns dict[str, str], but the caller sets this as a variable + # which is dict[str, str | None]. + results: dict[str, str | None] = {} # Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this. for tag in cast( list["etree._Element"], @@ -233,7 +230,7 @@ def _get_meta_tags( return results -def _map_twitter_to_open_graph(key: str) -> Optional[str]: +def _map_twitter_to_open_graph(key: str) -> str | None: """ Map a Twitter card property to the analogous Open Graph property. @@ -253,7 +250,7 @@ def _map_twitter_to_open_graph(key: str) -> Optional[str]: return "og" + key[7:] -def parse_html_to_open_graph(tree: "etree._Element") -> dict[str, Optional[str]]: +def parse_html_to_open_graph(tree: "etree._Element") -> dict[str, str | None]: """ Parse the HTML document into an Open Graph response. @@ -387,7 +384,7 @@ def parse_html_to_open_graph(tree: "etree._Element") -> dict[str, Optional[str]] return og -def parse_html_description(tree: "etree._Element") -> Optional[str]: +def parse_html_description(tree: "etree._Element") -> str | None: """ Calculate a text description based on an HTML document. @@ -460,7 +457,7 @@ def _iterate_over_text( # This is a stack whose items are elements to iterate over *or* strings # to be returned. - elements: list[Union[str, "etree._Element"]] = [tree] + elements: list[str | "etree._Element"] = [tree] while elements: el = elements.pop() @@ -496,7 +493,7 @@ def _iterate_over_text( def summarize_paragraphs( text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 -) -> Optional[str]: +) -> str | None: """ Try to get a summary respecting first paragraph and then word boundaries. diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py index 300952025a..a87ffa0892 100644 --- a/synapse/media/storage_provider.py +++ b/synapse/media/storage_provider.py @@ -23,7 +23,7 @@ import abc import logging import os import shutil -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background @@ -55,7 +55,7 @@ class StorageProvider(metaclass=abc.ABCMeta): """ @abc.abstractmethod - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + async def fetch(self, path: str, file_info: FileInfo) -> Responder | None: """Attempt to fetch the file described by file_info and stream it into writer. @@ -124,7 +124,7 @@ class StorageProviderWrapper(StorageProvider): run_in_background(store) @trace_with_opname("StorageProviderWrapper.fetch") - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + async def fetch(self, path: str, file_info: FileInfo) -> Responder | None: if file_info.url_cache: # Files in the URL preview cache definitely aren't stored here, # so avoid any potentially slow I/O or network access. @@ -173,7 +173,7 @@ class FileStorageProviderBackend(StorageProvider): ) @trace_with_opname("FileStorageProviderBackend.fetch") - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + async def fetch(self, path: str, file_info: FileInfo) -> Responder | None: """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index a42d39c319..fd65131c63 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -22,7 +22,7 @@ import logging from io import BytesIO from types import TracebackType -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from PIL import Image @@ -237,9 +237,9 @@ class Thumbnailer: def __exit__( self, - type: Optional[type[BaseException]], - value: Optional[BaseException], - traceback: Optional[TracebackType], + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.close() @@ -549,8 +549,8 @@ class ThumbnailProvider: file_id: str, url_cache: bool, for_federation: bool, - media_info: Optional[LocalMedia] = None, - server_name: Optional[str] = None, + media_info: LocalMedia | None = None, + server_name: str | None = None, ) -> None: """ Respond to a request with an appropriate thumbnail from the previously generated thumbnails. @@ -713,8 +713,8 @@ class ThumbnailProvider: thumbnail_infos: list[ThumbnailInfo], file_id: str, url_cache: bool, - server_name: Optional[str], - ) -> Optional[FileInfo]: + server_name: str | None, + ) -> FileInfo | None: """ Choose an appropriate thumbnail from the previously generated thumbnails. @@ -742,11 +742,11 @@ class ThumbnailProvider: if desired_method == "crop": # Thumbnails that match equal or larger sizes of desired width/height. crop_info_list: list[ - tuple[int, int, int, bool, Optional[int], ThumbnailInfo] + tuple[int, int, int, bool, int | None, ThumbnailInfo] ] = [] # Other thumbnails. crop_info_list2: list[ - tuple[int, int, int, bool, Optional[int], ThumbnailInfo] + tuple[int, int, int, bool, int | None, ThumbnailInfo] ] = [] for info in thumbnail_infos: # Skip thumbnails generated with different methods. diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index 2a63842fb7..bbd8017b13 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -28,7 +28,7 @@ import re import shutil import sys import traceback -from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional +from typing import TYPE_CHECKING, BinaryIO, Iterable from urllib.parse import urljoin, urlparse, urlsplit from urllib.request import urlopen @@ -70,9 +70,9 @@ class DownloadResult: uri: str response_code: int media_type: str - download_name: Optional[str] + download_name: str | None expires: int - etag: Optional[str] + etag: str | None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -87,7 +87,7 @@ class MediaInfo: media_length: int # The media filename, according to the server. This is parsed from the # returned headers, if possible. - download_name: Optional[str] + download_name: str | None # The time of the preview. created_ts_ms: int # Information from the media storage provider about where the file is stored @@ -101,7 +101,7 @@ class MediaInfo: # The timestamp (in milliseconds) of when this preview expires. expires: int # The ETag header of the response. - etag: Optional[str] + etag: str | None class UrlPreviewer: @@ -268,7 +268,7 @@ class UrlPreviewer: # The number of milliseconds that the response should be considered valid. expiration_ms = media_info.expires - author_name: Optional[str] = None + author_name: str | None = None if _is_media(media_info.media_type): file_id = media_info.filesystem_id @@ -705,7 +705,7 @@ class UrlPreviewer: async def _handle_oembed_response( self, url: str, media_info: MediaInfo, expiration_ms: int - ) -> tuple[JsonDict, Optional[str], int]: + ) -> tuple[JsonDict, str | None, int]: """ Parse the downloaded oEmbed info. diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index def21ac942..cf7b2f1da0 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -31,10 +31,8 @@ from typing import ( Generic, Iterable, Mapping, - Optional, Sequence, TypeVar, - Union, cast, ) @@ -156,12 +154,10 @@ class LaterGauge(Collector): name: str desc: str - labelnames: Optional[StrSequence] = attr.ib(hash=False) + labelnames: StrSequence | None = attr.ib(hash=False) _instance_id_to_hook_map: dict[ - Optional[str], # instance_id - Callable[ - [], Union[Mapping[tuple[str, ...], Union[int, float]], Union[int, float]] - ], + str | None, # instance_id + Callable[[], Mapping[tuple[str, ...], int | float] | int | float], ] = attr.ib(factory=dict, hash=False) """ Map from homeserver instance_id to a callback. Each callback should either return a @@ -200,10 +196,8 @@ class LaterGauge(Collector): def register_hook( self, *, - homeserver_instance_id: Optional[str], - hook: Callable[ - [], Union[Mapping[tuple[str, ...], Union[int, float]], Union[int, float]] - ], + homeserver_instance_id: str | None, + hook: Callable[[], Mapping[tuple[str, ...], int | float] | int | float], ) -> None: """ Register a callback/hook that will be called to generate a metric samples for @@ -420,7 +414,7 @@ class GaugeHistogramMetricFamilyWithLabels(GaugeHistogramMetricFamily): name: str, documentation: str, gsum_value: float, - buckets: Optional[Sequence[tuple[str, float]]] = None, + buckets: Sequence[tuple[str, float]] | None = None, labelnames: StrSequence = (), labelvalues: StrSequence = (), unit: str = "", @@ -471,7 +465,7 @@ class GaugeBucketCollector(Collector): *, name: str, documentation: str, - labelnames: Optional[StrSequence], + labelnames: StrSequence | None, buckets: Iterable[float], registry: CollectorRegistry = REGISTRY, ): @@ -497,7 +491,7 @@ class GaugeBucketCollector(Collector): # We initially set this to None. We won't report metrics until # this has been initialised after a successful data update - self._metric: Optional[GaugeHistogramMetricFamilyWithLabels] = None + self._metric: GaugeHistogramMetricFamilyWithLabels | None = None registry.register(self) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index c871598680..8ff2803455 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -34,7 +34,6 @@ from typing import ( Optional, Protocol, TypeVar, - Union, ) from prometheus_client import Metric @@ -188,7 +187,7 @@ class _BackgroundProcess: self.desc = desc self.server_name = server_name self._context = ctx - self._reported_stats: Optional[ContextResourceUsage] = None + self._reported_stats: ContextResourceUsage | None = None def update_metrics(self) -> None: """Updates the metrics with values from this process.""" @@ -224,12 +223,12 @@ R = TypeVar("R") def run_as_background_process( desc: "LiteralString", server_name: str, - func: Callable[..., Awaitable[Optional[R]]], + func: Callable[..., Awaitable[R | None]], *args: Any, bg_start_span: bool = True, test_only_tracer: Optional["opentracing.Tracer"] = None, **kwargs: Any, -) -> "defer.Deferred[Optional[R]]": +) -> "defer.Deferred[R | None]": """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -270,7 +269,7 @@ def run_as_background_process( # trace. original_active_tracing_span = active_span(tracer=test_only_tracer) - async def run() -> Optional[R]: + async def run() -> R | None: with _bg_metrics_lock: count = _background_process_counts.get(desc, 0) _background_process_counts[desc] = count + 1 @@ -425,8 +424,8 @@ class HasHomeServer(Protocol): def wrap_as_background_process( desc: "LiteralString", ) -> Callable[ - [Callable[P, Awaitable[Optional[R]]]], - Callable[P, "defer.Deferred[Optional[R]]"], + [Callable[P, Awaitable[R | None]]], + Callable[P, "defer.Deferred[R | None]"], ]: """Decorator that wraps an asynchronous function `func`, returning a synchronous decorated function. Calling the decorated version runs `func` as a background @@ -448,12 +447,12 @@ def wrap_as_background_process( """ def wrapper( - func: Callable[Concatenate[HasHomeServer, P], Awaitable[Optional[R]]], - ) -> Callable[P, "defer.Deferred[Optional[R]]"]: + func: Callable[Concatenate[HasHomeServer, P], Awaitable[R | None]], + ) -> Callable[P, "defer.Deferred[R | None]"]: @wraps(func) def wrapped_func( self: HasHomeServer, *args: P.args, **kwargs: P.kwargs - ) -> "defer.Deferred[Optional[R]]": + ) -> "defer.Deferred[R | None]": assert self.hs is not None, ( "The `hs` attribute must be set on the object where `@wrap_as_background_process` decorator is used." ) @@ -487,7 +486,7 @@ class BackgroundProcessLoggingContext(LoggingContext): *, name: str, server_name: str, - instance_id: Optional[Union[int, str]] = None, + instance_id: int | str | None = None, ): """ @@ -503,11 +502,11 @@ class BackgroundProcessLoggingContext(LoggingContext): if instance_id is None: instance_id = id(self) super().__init__(name="%s-%s" % (name, instance_id), server_name=server_name) - self._proc: Optional[_BackgroundProcess] = _BackgroundProcess( + self._proc: _BackgroundProcess | None = _BackgroundProcess( desc=name, server_name=server_name, ctx=self ) - def start(self, rusage: "Optional[resource.struct_rusage]") -> None: + def start(self, rusage: "resource.struct_rusage | None") -> None: """Log context has started running (again).""" super().start(rusage) @@ -528,9 +527,9 @@ class BackgroundProcessLoggingContext(LoggingContext): def __exit__( self, - type: Optional[type[BaseException]], - value: Optional[BaseException], - traceback: Optional[TracebackType], + type: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, ) -> None: """Log context has finished.""" diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py index fb8adbe060..03cecec3ca 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py @@ -23,7 +23,7 @@ import ctypes import logging import os import re -from typing import Iterable, Literal, Optional, overload +from typing import Iterable, Literal, overload import attr from prometheus_client import REGISTRY, Metric @@ -40,17 +40,17 @@ class JemallocStats: @overload def _mallctl( - self, name: str, read: Literal[True] = True, write: Optional[int] = None + self, name: str, read: Literal[True] = True, write: int | None = None ) -> int: ... @overload def _mallctl( - self, name: str, read: Literal[False], write: Optional[int] = None + self, name: str, read: Literal[False], write: int | None = None ) -> None: ... def _mallctl( - self, name: str, read: bool = True, write: Optional[int] = None - ) -> Optional[int]: + self, name: str, read: bool = True, write: int | None = None + ) -> int | None: """Wrapper around `mallctl` for reading and writing integers to jemalloc. @@ -131,10 +131,10 @@ class JemallocStats: return self._mallctl(f"stats.{name}") -_JEMALLOC_STATS: Optional[JemallocStats] = None +_JEMALLOC_STATS: JemallocStats | None = None -def get_jemalloc_stats() -> Optional[JemallocStats]: +def get_jemalloc_stats() -> JemallocStats | None: """Returns an interface to jemalloc, if it is being used. Note that this will always return None until `setup_jemalloc_stats` has been diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 9287747cea..6a2d152e3f 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -29,9 +29,7 @@ from typing import ( Generator, Iterable, Mapping, - Optional, TypeVar, - Union, ) import attr @@ -227,11 +225,11 @@ class UserIpAndAgent: def run_as_background_process( desc: "LiteralString", - func: Callable[..., Awaitable[Optional[T]]], + func: Callable[..., Awaitable[T | None]], *args: Any, bg_start_span: bool = True, **kwargs: Any, -) -> "defer.Deferred[Optional[T]]": +) -> "defer.Deferred[T | None]": """ XXX: Deprecated: use `ModuleApi.run_as_background_process` instead. @@ -295,8 +293,8 @@ def run_as_background_process( def cached( *, max_entries: int = 1000, - num_args: Optional[int] = None, - uncached_args: Optional[Collection[str]] = None, + num_args: int | None = None, + uncached_args: Collection[str] | None = None, ) -> Callable[[F], CachedFunction[F]]: """Returns a decorator that applies a memoizing cache around the function. This decorator behaves similarly to functools.lru_cache. @@ -338,7 +336,7 @@ class ModuleApi: # TODO: Fix this type hint once the types for the data stores have been ironed # out. - self._store: Union[DataStore, "GenericWorkerStore"] = hs.get_datastores().main + self._store: DataStore | "GenericWorkerStore" = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._auth_handler = auth_handler @@ -387,26 +385,20 @@ class ModuleApi: def register_spam_checker_callbacks( self, *, - check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, - should_drop_federated_event: Optional[ - SHOULD_DROP_FEDERATED_EVENT_CALLBACK - ] = None, - user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, - user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, - federated_user_may_invite: Optional[FEDERATED_USER_MAY_INVITE_CALLBACK] = None, - user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, - user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, - user_may_create_room_alias: Optional[ - USER_MAY_CREATE_ROOM_ALIAS_CALLBACK - ] = None, - user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None, - user_may_send_state_event: Optional[USER_MAY_SEND_STATE_EVENT_CALLBACK] = None, - check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None, - check_registration_for_spam: Optional[ - CHECK_REGISTRATION_FOR_SPAM_CALLBACK - ] = None, - check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, - check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None, + check_event_for_spam: CHECK_EVENT_FOR_SPAM_CALLBACK | None = None, + should_drop_federated_event: SHOULD_DROP_FEDERATED_EVENT_CALLBACK | None = None, + user_may_join_room: USER_MAY_JOIN_ROOM_CALLBACK | None = None, + user_may_invite: USER_MAY_INVITE_CALLBACK | None = None, + federated_user_may_invite: FEDERATED_USER_MAY_INVITE_CALLBACK | None = None, + user_may_send_3pid_invite: USER_MAY_SEND_3PID_INVITE_CALLBACK | None = None, + user_may_create_room: USER_MAY_CREATE_ROOM_CALLBACK | None = None, + user_may_create_room_alias: USER_MAY_CREATE_ROOM_ALIAS_CALLBACK | None = None, + user_may_publish_room: USER_MAY_PUBLISH_ROOM_CALLBACK | None = None, + user_may_send_state_event: USER_MAY_SEND_STATE_EVENT_CALLBACK | None = None, + check_username_for_spam: CHECK_USERNAME_FOR_SPAM_CALLBACK | None = None, + check_registration_for_spam: CHECK_REGISTRATION_FOR_SPAM_CALLBACK | None = None, + check_media_file_for_spam: CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK | None = None, + check_login_for_spam: CHECK_LOGIN_FOR_SPAM_CALLBACK | None = None, ) -> None: """Registers callbacks for spam checking capabilities. @@ -432,12 +424,12 @@ class ModuleApi: def register_account_validity_callbacks( self, *, - is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, - on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, - on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None, - on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, - on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, - on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, + is_user_expired: IS_USER_EXPIRED_CALLBACK | None = None, + on_user_registration: ON_USER_REGISTRATION_CALLBACK | None = None, + on_user_login: ON_USER_LOGIN_CALLBACK | None = None, + on_legacy_send_mail: ON_LEGACY_SEND_MAIL_CALLBACK | None = None, + on_legacy_renew: ON_LEGACY_RENEW_CALLBACK | None = None, + on_legacy_admin_request: ON_LEGACY_ADMIN_REQUEST | None = None, ) -> None: """Registers callbacks for account validity capabilities. @@ -455,9 +447,8 @@ class ModuleApi: def register_ratelimit_callbacks( self, *, - get_ratelimit_override_for_user: Optional[ - GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK - ] = None, + get_ratelimit_override_for_user: GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK + | None = None, ) -> None: """Registers callbacks for ratelimit capabilities. Added in Synapse v1.132.0. @@ -469,16 +460,13 @@ class ModuleApi: def register_media_repository_callbacks( self, *, - get_media_config_for_user: Optional[GET_MEDIA_CONFIG_FOR_USER_CALLBACK] = None, - is_user_allowed_to_upload_media_of_size: Optional[ - IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK - ] = None, - get_media_upload_limits_for_user: Optional[ - GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK - ] = None, - on_media_upload_limit_exceeded: Optional[ - ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK - ] = None, + get_media_config_for_user: GET_MEDIA_CONFIG_FOR_USER_CALLBACK | None = None, + is_user_allowed_to_upload_media_of_size: IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK + | None = None, + get_media_upload_limits_for_user: GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK + | None = None, + on_media_upload_limit_exceeded: ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK + | None = None, ) -> None: """Registers callbacks for media repository capabilities. Added in Synapse v1.132.0. @@ -493,28 +481,23 @@ class ModuleApi: def register_third_party_rules_callbacks( self, *, - check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, - on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, - check_threepid_can_be_invited: Optional[ - CHECK_THREEPID_CAN_BE_INVITED_CALLBACK - ] = None, - check_visibility_can_be_modified: Optional[ - CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK - ] = None, - on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, - check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None, - check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None, - on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, - on_user_deactivation_status_changed: Optional[ - ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK - ] = None, - on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, - on_add_user_third_party_identifier: Optional[ - ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = None, - on_remove_user_third_party_identifier: Optional[ - ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = None, + check_event_allowed: CHECK_EVENT_ALLOWED_CALLBACK | None = None, + on_create_room: ON_CREATE_ROOM_CALLBACK | None = None, + check_threepid_can_be_invited: CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + | None = None, + check_visibility_can_be_modified: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + | None = None, + on_new_event: ON_NEW_EVENT_CALLBACK | None = None, + check_can_shutdown_room: CHECK_CAN_SHUTDOWN_ROOM_CALLBACK | None = None, + check_can_deactivate_user: CHECK_CAN_DEACTIVATE_USER_CALLBACK | None = None, + on_profile_update: ON_PROFILE_UPDATE_CALLBACK | None = None, + on_user_deactivation_status_changed: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + | None = None, + on_threepid_bind: ON_THREEPID_BIND_CALLBACK | None = None, + on_add_user_third_party_identifier: ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + | None = None, + on_remove_user_third_party_identifier: ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + | None = None, ) -> None: """Registers callbacks for third party event rules capabilities. @@ -538,8 +521,8 @@ class ModuleApi: def register_presence_router_callbacks( self, *, - get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, - get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, + get_users_for_states: GET_USERS_FOR_STATES_CALLBACK | None = None, + get_interested_users: GET_INTERESTED_USERS_CALLBACK | None = None, ) -> None: """Registers callbacks for presence router capabilities. @@ -553,18 +536,15 @@ class ModuleApi: def register_password_auth_provider_callbacks( self, *, - check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, - on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, - auth_checkers: Optional[ - dict[tuple[str, tuple[str, ...]], CHECK_AUTH_CALLBACK] - ] = None, - is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None, - get_username_for_registration: Optional[ - GET_USERNAME_FOR_REGISTRATION_CALLBACK - ] = None, - get_displayname_for_registration: Optional[ - GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK - ] = None, + check_3pid_auth: CHECK_3PID_AUTH_CALLBACK | None = None, + on_logged_out: ON_LOGGED_OUT_CALLBACK | None = None, + auth_checkers: dict[tuple[str, tuple[str, ...]], CHECK_AUTH_CALLBACK] + | None = None, + is_3pid_allowed: IS_3PID_ALLOWED_CALLBACK | None = None, + get_username_for_registration: GET_USERNAME_FOR_REGISTRATION_CALLBACK + | None = None, + get_displayname_for_registration: GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + | None = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -588,8 +568,8 @@ class ModuleApi: self, *, on_update: ON_UPDATE_CALLBACK, - default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, - min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None, + default_batch_size: DEFAULT_BATCH_SIZE_CALLBACK | None = None, + min_batch_size: MIN_BATCH_SIZE_CALLBACK | None = None, ) -> None: """Registers background update controller callbacks. @@ -606,7 +586,7 @@ class ModuleApi: def register_account_data_callbacks( self, *, - on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None, + on_account_data_updated: ON_ACCOUNT_DATA_UPDATED_CALLBACK | None = None, ) -> None: """Registers account data callbacks. @@ -635,9 +615,8 @@ class ModuleApi: def register_add_extra_fields_to_unsigned_client_event_callbacks( self, *, - add_field_to_unsigned_callback: Optional[ - ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK - ] = None, + add_field_to_unsigned_callback: ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK + | None = None, ) -> None: """Registers a callback that can be used to add fields to the unsigned section of events. @@ -708,7 +687,7 @@ class ModuleApi: return self._server_name @property - def worker_name(self) -> Optional[str]: + def worker_name(self) -> str | None: """The name of the worker this specific instance is running as per the "worker_name" configuration setting, or None if it's the main process. @@ -717,7 +696,7 @@ class ModuleApi: return self._hs.config.worker.worker_name @property - def worker_app(self) -> Optional[str]: + def worker_app(self) -> str | None: """The name of the worker app this specific instance is running as per the "worker_app" configuration setting, or None if it's the main process. @@ -725,7 +704,7 @@ class ModuleApi: """ return self._hs.config.worker.worker_app - async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + async def get_userinfo_by_id(self, user_id: str) -> UserInfo | None: """Get user info by user_id Added in Synapse v1.41.0. @@ -843,7 +822,7 @@ class ModuleApi: """ return [attr.asdict(t) for t in await self._store.user_get_threepids(user_id)] - def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]": + def check_user_exists(self, user_id: str) -> "defer.Deferred[str | None]": """Check if user exists. Added in Synapse v0.25.0. @@ -861,8 +840,8 @@ class ModuleApi: def register( self, localpart: str, - displayname: Optional[str] = None, - emails: Optional[list[str]] = None, + displayname: str | None = None, + emails: list[str] | None = None, ) -> Generator["defer.Deferred[Any]", Any, tuple[str, str]]: """Registers a new user with given localpart and optional displayname, emails. @@ -892,8 +871,8 @@ class ModuleApi: def register_user( self, localpart: str, - displayname: Optional[str] = None, - emails: Optional[list[str]] = None, + displayname: str | None = None, + emails: list[str] | None = None, admin: bool = False, ) -> "defer.Deferred[str]": """Registers a new user with given localpart and optional displayname, emails. @@ -926,9 +905,9 @@ class ModuleApi: def register_device( self, user_id: str, - device_id: Optional[str] = None, - initial_display_name: Optional[str] = None, - ) -> "defer.Deferred[tuple[str, str, Optional[int], Optional[str]]]": + device_id: str | None = None, + initial_display_name: str | None = None, + ) -> "defer.Deferred[tuple[str, str, int | None, str | None]]": """Register a device for a user and generate an access token. Added in Synapse v1.2.0. @@ -978,8 +957,8 @@ class ModuleApi: self, user_id: str, duration_in_ms: int = (2 * 60 * 1000), - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, + auth_provider_id: str | None = None, + auth_provider_session_id: str | None = None, ) -> str: """Create a login token suitable for m.login.token authentication @@ -1135,7 +1114,7 @@ class ModuleApi: @defer.inlineCallbacks def get_state_events_in_room( - self, room_id: str, types: Iterable[tuple[str, Optional[str]]] + self, room_id: str, types: Iterable[tuple[str, str | None]] ) -> Generator[defer.Deferred, Any, Iterable[EventBase]]: """Gets current state events for the given room. @@ -1166,8 +1145,8 @@ class ModuleApi: target: str, room_id: str, new_membership: str, - content: Optional[JsonDict] = None, - remote_room_hosts: Optional[list[str]] = None, + content: JsonDict | None = None, + remote_room_hosts: list[str] | None = None, ) -> EventBase: """Updates the membership of a user to the given value. @@ -1343,7 +1322,7 @@ class ModuleApi: ) async def set_presence_for_users( - self, users: Mapping[str, tuple[str, Optional[str]]] + self, users: Mapping[str, tuple[str, str | None]] ) -> None: """ Update the internal presence state of users. @@ -1378,7 +1357,7 @@ class ModuleApi: f: Callable, msec: float, *args: object, - desc: Optional[str] = None, + desc: str | None = None, run_on_all_instances: bool = False, **kwargs: object, ) -> None: @@ -1437,7 +1416,7 @@ class ModuleApi: msec: float, f: Callable, *args: object, - desc: Optional[str] = None, + desc: str | None = None, **kwargs: object, ) -> IDelayedCall: """Wraps a function as a background process and calls it in a given number of milliseconds. @@ -1483,10 +1462,10 @@ class ModuleApi: async def send_http_push_notification( self, user_id: str, - device_id: Optional[str], + device_id: str | None, content: JsonDict, - tweaks: Optional[JsonMapping] = None, - default_payload: Optional[JsonMapping] = None, + tweaks: JsonMapping | None = None, + default_payload: JsonMapping | None = None, ) -> dict[str, bool]: """Send an HTTP push notification that is forwarded to the registered push gateway for the specified user/device. @@ -1552,7 +1531,7 @@ class ModuleApi: def read_templates( self, filenames: list[str], - custom_template_directory: Optional[str] = None, + custom_template_directory: str | None = None, ) -> list[jinja2.Template]: """Read and load the content of the template files at the given location. By default, Synapse will look for these templates in its configured template @@ -1573,7 +1552,7 @@ class ModuleApi: (td for td in (self.custom_template_dir, custom_template_directory) if td), ) - def is_mine(self, id: Union[str, DomainSpecificString]) -> bool: + def is_mine(self, id: str | DomainSpecificString) -> bool: """ Checks whether an ID (user id, room, ...) comes from this homeserver. @@ -1635,7 +1614,7 @@ class ModuleApi: async def get_room_state( self, room_id: str, - event_filter: Optional[Iterable[tuple[str, Optional[str]]]] = None, + event_filter: Iterable[tuple[str, str | None]] | None = None, ) -> StateMap[EventBase]: """Returns the current state of the given room. @@ -1677,11 +1656,11 @@ class ModuleApi: def run_as_background_process( self, desc: "LiteralString", - func: Callable[..., Awaitable[Optional[T]]], + func: Callable[..., Awaitable[T | None]], *args: Any, bg_start_span: bool = True, **kwargs: Any, - ) -> "defer.Deferred[Optional[T]]": + ) -> "defer.Deferred[T | None]": """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -1799,9 +1778,7 @@ class ModuleApi: """ await self._store.add_user_bound_threepid(user_id, medium, address, id_server) - def check_push_rule_actions( - self, actions: list[Union[str, dict[str, str]]] - ) -> None: + def check_push_rule_actions(self, actions: list[str | dict[str, str]]) -> None: """Checks if the given push rule actions are valid according to the Matrix specification. @@ -1824,7 +1801,7 @@ class ModuleApi: scope: str, kind: str, rule_id: str, - actions: list[Union[str, dict[str, str]]], + actions: list[str | dict[str, str]], ) -> None: """Changes the actions of an existing push rule for the given user. @@ -1862,7 +1839,7 @@ class ModuleApi: ) async def get_monthly_active_users_by_service( - self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None + self, start_timestamp: int | None = None, end_timestamp: int | None = None ) -> list[tuple[str, str]]: """Generates list of monthly active users and their services. Please see corresponding storage docstring for more details. @@ -1883,7 +1860,7 @@ class ModuleApi: start_timestamp, end_timestamp ) - async def get_canonical_room_alias(self, room_id: RoomID) -> Optional[RoomAlias]: + async def get_canonical_room_alias(self, room_id: RoomID) -> RoomAlias | None: """ Retrieve the given room's current canonical alias. @@ -1938,8 +1915,8 @@ class ModuleApi: user_id: str, config: JsonDict, ratelimit: bool = True, - creator_join_profile: Optional[JsonDict] = None, - ) -> tuple[str, Optional[str]]: + creator_join_profile: JsonDict | None = None, + ) -> tuple[str, str | None]: """Creates a new room. Added in Synapse v1.65.0. @@ -2109,7 +2086,7 @@ class AccountDataManager: f"{user_id} is not local to this homeserver; can't access account data for remote users." ) - async def get_global(self, user_id: str, data_type: str) -> Optional[JsonMapping]: + async def get_global(self, user_id: str, data_type: str) -> JsonMapping | None: """ Gets some global account data, of a specified type, for the specified user. diff --git a/synapse/module_api/callbacks/account_validity_callbacks.py b/synapse/module_api/callbacks/account_validity_callbacks.py index da01414d9a..892e9c8ecb 100644 --- a/synapse/module_api/callbacks/account_validity_callbacks.py +++ b/synapse/module_api/callbacks/account_validity_callbacks.py @@ -20,16 +20,16 @@ # import logging -from typing import Awaitable, Callable, Optional +from typing import Awaitable, Callable from twisted.web.http import Request logger = logging.getLogger(__name__) # Types for callbacks to be registered via the module api -IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] +IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[bool | None]] ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] -ON_USER_LOGIN_CALLBACK = Callable[[str, Optional[str], Optional[str]], Awaitable] +ON_USER_LOGIN_CALLBACK = Callable[[str, str | None, str | None], Awaitable] # Temporary hooks to allow for a transition from `/_matrix/client` endpoints # to `/_synapse/client/account_validity`. See `register_callbacks` below. ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] @@ -42,21 +42,21 @@ class AccountValidityModuleApiCallbacks: self.is_user_expired_callbacks: list[IS_USER_EXPIRED_CALLBACK] = [] self.on_user_registration_callbacks: list[ON_USER_REGISTRATION_CALLBACK] = [] self.on_user_login_callbacks: list[ON_USER_LOGIN_CALLBACK] = [] - self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None - self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None + self.on_legacy_send_mail_callback: ON_LEGACY_SEND_MAIL_CALLBACK | None = None + self.on_legacy_renew_callback: ON_LEGACY_RENEW_CALLBACK | None = None # The legacy admin requests callback isn't a protected attribute because we need # to access it from the admin servlet, which is outside of this handler. - self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None + self.on_legacy_admin_request_callback: ON_LEGACY_ADMIN_REQUEST | None = None def register_callbacks( self, - is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, - on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, - on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None, - on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, - on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, - on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, + is_user_expired: IS_USER_EXPIRED_CALLBACK | None = None, + on_user_registration: ON_USER_REGISTRATION_CALLBACK | None = None, + on_user_login: ON_USER_LOGIN_CALLBACK | None = None, + on_legacy_send_mail: ON_LEGACY_SEND_MAIL_CALLBACK | None = None, + on_legacy_renew: ON_LEGACY_RENEW_CALLBACK | None = None, + on_legacy_admin_request: ON_LEGACY_ADMIN_REQUEST | None = None, ) -> None: """Register callbacks from module for each hook.""" if is_user_expired is not None: diff --git a/synapse/module_api/callbacks/media_repository_callbacks.py b/synapse/module_api/callbacks/media_repository_callbacks.py index 7cb56e558b..f1e6ea4c38 100644 --- a/synapse/module_api/callbacks/media_repository_callbacks.py +++ b/synapse/module_api/callbacks/media_repository_callbacks.py @@ -13,7 +13,7 @@ # import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Optional +from typing import TYPE_CHECKING, Awaitable, Callable from synapse.config.repository import MediaUploadLimit from synapse.types import JsonDict @@ -25,12 +25,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -GET_MEDIA_CONFIG_FOR_USER_CALLBACK = Callable[[str], Awaitable[Optional[JsonDict]]] +GET_MEDIA_CONFIG_FOR_USER_CALLBACK = Callable[[str], Awaitable[JsonDict | None]] IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK = Callable[[str, int], Awaitable[bool]] GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK = Callable[ - [str], Awaitable[Optional[list[MediaUploadLimit]]] + [str], Awaitable[list[MediaUploadLimit] | None] ] ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK = Callable[ @@ -57,16 +57,13 @@ class MediaRepositoryModuleApiCallbacks: def register_callbacks( self, - get_media_config_for_user: Optional[GET_MEDIA_CONFIG_FOR_USER_CALLBACK] = None, - is_user_allowed_to_upload_media_of_size: Optional[ - IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK - ] = None, - get_media_upload_limits_for_user: Optional[ - GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK - ] = None, - on_media_upload_limit_exceeded: Optional[ - ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK - ] = None, + get_media_config_for_user: GET_MEDIA_CONFIG_FOR_USER_CALLBACK | None = None, + is_user_allowed_to_upload_media_of_size: IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK + | None = None, + get_media_upload_limits_for_user: GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK + | None = None, + on_media_upload_limit_exceeded: ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK + | None = None, ) -> None: """Register callbacks from module for each hook.""" if get_media_config_for_user is not None: @@ -87,14 +84,14 @@ class MediaRepositoryModuleApiCallbacks: on_media_upload_limit_exceeded ) - async def get_media_config_for_user(self, user_id: str) -> Optional[JsonDict]: + async def get_media_config_for_user(self, user_id: str) -> JsonDict | None: for callback in self._get_media_config_for_user_callbacks: with Measure( self.clock, name=f"{callback.__module__}.{callback.__qualname__}", server_name=self.server_name, ): - res: Optional[JsonDict] = await delay_cancellation(callback(user_id)) + res: JsonDict | None = await delay_cancellation(callback(user_id)) if res: return res @@ -117,7 +114,7 @@ class MediaRepositoryModuleApiCallbacks: async def get_media_upload_limits_for_user( self, user_id: str - ) -> Optional[list[MediaUploadLimit]]: + ) -> list[MediaUploadLimit] | None: """ Get the first non-None list of MediaUploadLimits for the user from the registered callbacks. If a list is returned it will be sorted in descending order of duration. @@ -128,7 +125,7 @@ class MediaRepositoryModuleApiCallbacks: name=f"{callback.__module__}.{callback.__qualname__}", server_name=self.server_name, ): - res: Optional[list[MediaUploadLimit]] = await delay_cancellation( + res: list[MediaUploadLimit] | None = await delay_cancellation( callback(user_id) ) if res is not None: # to allow [] to be returned meaning no limit diff --git a/synapse/module_api/callbacks/ratelimit_callbacks.py b/synapse/module_api/callbacks/ratelimit_callbacks.py index 6afcda1216..0f4080dcd6 100644 --- a/synapse/module_api/callbacks/ratelimit_callbacks.py +++ b/synapse/module_api/callbacks/ratelimit_callbacks.py @@ -13,7 +13,7 @@ # import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Optional +from typing import TYPE_CHECKING, Awaitable, Callable import attr @@ -37,7 +37,7 @@ class RatelimitOverride: GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK = Callable[ - [str, str], Awaitable[Optional[RatelimitOverride]] + [str, str], Awaitable[RatelimitOverride | None] ] @@ -51,9 +51,8 @@ class RatelimitModuleApiCallbacks: def register_callbacks( self, - get_ratelimit_override_for_user: Optional[ - GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK - ] = None, + get_ratelimit_override_for_user: GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK + | None = None, ) -> None: """Register callbacks from module for each hook.""" if get_ratelimit_override_for_user is not None: @@ -63,14 +62,14 @@ class RatelimitModuleApiCallbacks: async def get_ratelimit_override_for_user( self, user_id: str, limiter_name: str - ) -> Optional[RatelimitOverride]: + ) -> RatelimitOverride | None: for callback in self._get_ratelimit_override_for_user_callbacks: with Measure( self.clock, name=f"{callback.__module__}.{callback.__qualname__}", server_name=self.server_name, ): - res: Optional[RatelimitOverride] = await delay_cancellation( + res: RatelimitOverride | None = await delay_cancellation( callback(user_id, limiter_name) ) if res: diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py index 4c331c4210..8b34f7ef6c 100644 --- a/synapse/module_api/callbacks/spamchecker_callbacks.py +++ b/synapse/module_api/callbacks/spamchecker_callbacks.py @@ -30,8 +30,6 @@ from typing import ( Callable, Collection, Literal, - Optional, - Union, cast, ) @@ -53,210 +51,96 @@ logger = logging.getLogger(__name__) CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[ ["synapse.events.EventBase"], - Awaitable[ - Union[ - str, - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - # Deprecated - bool, - ] - ], + Awaitable[str | Codes | tuple[Codes, JsonDict] | bool], ] SHOULD_DROP_FEDERATED_EVENT_CALLBACK = Callable[ ["synapse.events.EventBase"], - Awaitable[Union[bool, str]], + Awaitable[bool | str], ] USER_MAY_JOIN_ROOM_CALLBACK = Callable[ [str, str, bool], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - # Deprecated - bool, - ] - ], + Awaitable[Literal["NOT_SPAM"] | Codes | tuple[Codes, JsonDict] | bool], ] USER_MAY_INVITE_CALLBACK = Callable[ [str, str, str], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - # Deprecated - bool, - ] - ], + Awaitable[Literal["NOT_SPAM"] | Codes | tuple[Codes, JsonDict] | bool], ] FEDERATED_USER_MAY_INVITE_CALLBACK = Callable[ ["synapse.events.EventBase"], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - # Deprecated - bool, - ] - ], + Awaitable[Literal["NOT_SPAM"] | Codes | tuple[Codes, JsonDict] | bool], ] USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[ [str, str, str, str], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - # Deprecated - bool, - ] - ], + Awaitable[Literal["NOT_SPAM"] | Codes | tuple[Codes, JsonDict] | bool], ] -USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE = Union[ - Literal["NOT_SPAM"], - Codes, +USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE = ( + Literal["NOT_SPAM"] + | Codes + | # Highly experimental, not officially part of the spamchecker API, may # disappear without warning depending on the results of ongoing # experiments. # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], + tuple[Codes, JsonDict] + | # Deprecated - bool, -] -USER_MAY_CREATE_ROOM_CALLBACK = Union[ + bool +) +USER_MAY_CREATE_ROOM_CALLBACK = ( Callable[ [str, JsonDict], Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE], - ], - Callable[ # Single argument variant for backwards compatibility + ] + | Callable[ # Single argument variant for backwards compatibility [str], Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE] - ], -] + ] +) USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[ [str, RoomAlias], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - # Deprecated - bool, - ] - ], + Awaitable[Literal["NOT_SPAM"] | Codes | tuple[Codes, JsonDict] | bool], ] USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[ [str, str], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - # Deprecated - bool, - ] - ], + Awaitable[Literal["NOT_SPAM"] | Codes | tuple[Codes, JsonDict] | bool], ] USER_MAY_SEND_STATE_EVENT_CALLBACK = Callable[ [str, str, str, str, JsonDict], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - ] - ], -] -CHECK_USERNAME_FOR_SPAM_CALLBACK = Union[ - Callable[[UserProfile], Awaitable[bool]], - Callable[[UserProfile, str], Awaitable[bool]], + Awaitable[Literal["NOT_SPAM"] | Codes | tuple[Codes, JsonDict]], ] +CHECK_USERNAME_FOR_SPAM_CALLBACK = ( + Callable[[UserProfile], Awaitable[bool]] + | Callable[[UserProfile, str], Awaitable[bool]] +) LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ [ - Optional[dict], - Optional[str], + dict | None, + str | None, Collection[tuple[str, str]], ], Awaitable[RegistrationBehaviour], ] CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ [ - Optional[dict], - Optional[str], + dict | None, + str | None, Collection[tuple[str, str]], - Optional[str], + str | None, ], Awaitable[RegistrationBehaviour], ] CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[ [ReadableFileWrapper, FileInfo], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - # Deprecated - bool, - ] - ], + Awaitable[Literal["NOT_SPAM"] | Codes | tuple[Codes, JsonDict] | bool], ] CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[ [ str, - Optional[str], - Optional[str], - Collection[tuple[Optional[str], str]], - Optional[str], - ], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - tuple[Codes, JsonDict], - ] + str | None, + str | None, + Collection[tuple[str | None, str]], + str | None, ], + Awaitable[Literal["NOT_SPAM"] | Codes | tuple[Codes, JsonDict]], ] @@ -292,7 +176,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: for spam_checker in spam_checkers: # Methods on legacy spam checkers might not be async, so we wrap them around a # wrapper that will call maybe_awaitable on the result. - def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + def async_wrapper(f: Callable | None) -> Callable[..., Awaitable] | None: # f might be None if the callback isn't implemented by the module. In this # case we don't want to register a callback at all so we return None. if f is None: @@ -308,11 +192,11 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: # that gives it only 3 arguments and drops the auth_provider_id on # the floor. def wrapper( - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[tuple[str, str]], - auth_provider_id: Optional[str], - ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]: + auth_provider_id: str | None, + ) -> Awaitable[RegistrationBehaviour] | RegistrationBehaviour: # Assertion required because mypy can't prove we won't # change `f` back to `None`. See # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions @@ -390,26 +274,20 @@ class SpamCheckerModuleApiCallbacks: def register_callbacks( self, - check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, - should_drop_federated_event: Optional[ - SHOULD_DROP_FEDERATED_EVENT_CALLBACK - ] = None, - user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, - user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, - federated_user_may_invite: Optional[FEDERATED_USER_MAY_INVITE_CALLBACK] = None, - user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, - user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, - user_may_create_room_alias: Optional[ - USER_MAY_CREATE_ROOM_ALIAS_CALLBACK - ] = None, - user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None, - check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None, - check_registration_for_spam: Optional[ - CHECK_REGISTRATION_FOR_SPAM_CALLBACK - ] = None, - check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, - check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None, - user_may_send_state_event: Optional[USER_MAY_SEND_STATE_EVENT_CALLBACK] = None, + check_event_for_spam: CHECK_EVENT_FOR_SPAM_CALLBACK | None = None, + should_drop_federated_event: SHOULD_DROP_FEDERATED_EVENT_CALLBACK | None = None, + user_may_join_room: USER_MAY_JOIN_ROOM_CALLBACK | None = None, + user_may_invite: USER_MAY_INVITE_CALLBACK | None = None, + federated_user_may_invite: FEDERATED_USER_MAY_INVITE_CALLBACK | None = None, + user_may_send_3pid_invite: USER_MAY_SEND_3PID_INVITE_CALLBACK | None = None, + user_may_create_room: USER_MAY_CREATE_ROOM_CALLBACK | None = None, + user_may_create_room_alias: USER_MAY_CREATE_ROOM_ALIAS_CALLBACK | None = None, + user_may_publish_room: USER_MAY_PUBLISH_ROOM_CALLBACK | None = None, + check_username_for_spam: CHECK_USERNAME_FOR_SPAM_CALLBACK | None = None, + check_registration_for_spam: CHECK_REGISTRATION_FOR_SPAM_CALLBACK | None = None, + check_media_file_for_spam: CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK | None = None, + check_login_for_spam: CHECK_LOGIN_FOR_SPAM_CALLBACK | None = None, + user_may_send_state_event: USER_MAY_SEND_STATE_EVENT_CALLBACK | None = None, ) -> None: """Register callbacks from module for each hook.""" if check_event_for_spam is not None: @@ -469,7 +347,7 @@ class SpamCheckerModuleApiCallbacks: @trace async def check_event_for_spam( self, event: "synapse.events.EventBase" - ) -> Union[tuple[Codes, JsonDict], str]: + ) -> tuple[Codes, JsonDict] | str: """Checks if a given event is considered "spammy" by this server. If the server considers an event spammy, then it will be rejected if @@ -532,7 +410,7 @@ class SpamCheckerModuleApiCallbacks: async def should_drop_federated_event( self, event: "synapse.events.EventBase" - ) -> Union[bool, str]: + ) -> bool | str: """Checks if a given federated event is considered "spammy" by this server. @@ -551,7 +429,7 @@ class SpamCheckerModuleApiCallbacks: name=f"{callback.__module__}.{callback.__qualname__}", server_name=self.server_name, ): - res: Union[bool, str] = await delay_cancellation(callback(event)) + res: bool | str = await delay_cancellation(callback(event)) if res: return res @@ -559,7 +437,7 @@ class SpamCheckerModuleApiCallbacks: async def user_may_join_room( self, user_id: str, room_id: str, is_invited: bool - ) -> Union[tuple[Codes, JsonDict], Literal["NOT_SPAM"]]: + ) -> tuple[Codes, JsonDict] | Literal["NOT_SPAM"]: """Checks if a given users is allowed to join a room. Not called when a user creates a room. @@ -603,7 +481,7 @@ class SpamCheckerModuleApiCallbacks: async def user_may_invite( self, inviter_userid: str, invitee_userid: str, room_id: str - ) -> Union[tuple[Codes, dict], Literal["NOT_SPAM"]]: + ) -> tuple[Codes, dict] | Literal["NOT_SPAM"]: """Checks if a given user may send an invite Args: @@ -648,7 +526,7 @@ class SpamCheckerModuleApiCallbacks: async def federated_user_may_invite( self, event: "synapse.events.EventBase" - ) -> Union[tuple[Codes, dict], Literal["NOT_SPAM"]]: + ) -> tuple[Codes, dict] | Literal["NOT_SPAM"]: """Checks if a given user may send an invite Args: @@ -689,7 +567,7 @@ class SpamCheckerModuleApiCallbacks: async def user_may_send_3pid_invite( self, inviter_userid: str, medium: str, address: str, room_id: str - ) -> Union[tuple[Codes, dict], Literal["NOT_SPAM"]]: + ) -> tuple[Codes, dict] | Literal["NOT_SPAM"]: """Checks if a given user may invite a given threepid into the room Note that if the threepid is already associated with a Matrix user ID, Synapse @@ -737,7 +615,7 @@ class SpamCheckerModuleApiCallbacks: async def user_may_create_room( self, userid: str, room_config: JsonDict - ) -> Union[tuple[Codes, dict], Literal["NOT_SPAM"]]: + ) -> tuple[Codes, dict] | Literal["NOT_SPAM"]: """Checks if a given user may create a room Args: @@ -803,7 +681,7 @@ class SpamCheckerModuleApiCallbacks: event_type: str, state_key: str, content: JsonDict, - ) -> Union[tuple[Codes, dict], Literal["NOT_SPAM"]]: + ) -> tuple[Codes, dict] | Literal["NOT_SPAM"]: """Checks if a given user may create a room with a given visibility Args: user_id: The ID of the user attempting to create a room @@ -836,7 +714,7 @@ class SpamCheckerModuleApiCallbacks: async def user_may_create_room_alias( self, userid: str, room_alias: RoomAlias - ) -> Union[tuple[Codes, dict], Literal["NOT_SPAM"]]: + ) -> tuple[Codes, dict] | Literal["NOT_SPAM"]: """Checks if a given user may create a room alias Args: @@ -874,7 +752,7 @@ class SpamCheckerModuleApiCallbacks: async def user_may_publish_room( self, userid: str, room_id: str - ) -> Union[tuple[Codes, dict], Literal["NOT_SPAM"]]: + ) -> tuple[Codes, dict] | Literal["NOT_SPAM"]: """Checks if a given user may publish a room to the directory Args: @@ -960,10 +838,10 @@ class SpamCheckerModuleApiCallbacks: async def check_registration_for_spam( self, - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[tuple[str, str]], - auth_provider_id: Optional[str] = None, + auth_provider_id: str | None = None, ) -> RegistrationBehaviour: """Checks if we should allow the given registration request. @@ -998,7 +876,7 @@ class SpamCheckerModuleApiCallbacks: @trace async def check_media_file_for_spam( self, file_wrapper: ReadableFileWrapper, file_info: FileInfo - ) -> Union[tuple[Codes, dict], Literal["NOT_SPAM"]]: + ) -> tuple[Codes, dict] | Literal["NOT_SPAM"]: """Checks if a piece of newly uploaded media should be blocked. This will be called for local uploads, downloads of remote media, each @@ -1011,7 +889,7 @@ class SpamCheckerModuleApiCallbacks: async def check_media_file_for_spam( self, file: ReadableFileWrapper, file_info: FileInfo - ) -> Union[Codes, Literal["NOT_SPAM"]]: + ) -> Codes | Literal["NOT_SPAM"]: buffer = BytesIO() await file.write_chunks_to(buffer.write) @@ -1058,11 +936,11 @@ class SpamCheckerModuleApiCallbacks: async def check_login_for_spam( self, user_id: str, - device_id: Optional[str], - initial_display_name: Optional[str], - request_info: Collection[tuple[Optional[str], str]], - auth_provider_id: Optional[str] = None, - ) -> Union[tuple[Codes, dict], Literal["NOT_SPAM"]]: + device_id: str | None, + initial_display_name: str | None, + request_info: Collection[tuple[str | None, str]], + auth_provider_id: str | None = None, + ) -> tuple[Codes, dict] | Literal["NOT_SPAM"]: """Checks if we should allow the given registration request. Args: diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py index 2b886cbabb..65f5a6b183 100644 --- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py +++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py @@ -19,7 +19,7 @@ # # import logging -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable from twisted.internet.defer import CancelledError @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) CHECK_EVENT_ALLOWED_CALLBACK = Callable[ - [EventBase, StateMap[EventBase]], Awaitable[tuple[bool, Optional[dict]]] + [EventBase, StateMap[EventBase]], Awaitable[tuple[bool, dict | None]] ] ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable] CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[ @@ -47,7 +47,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] -CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[Optional[str], str], Awaitable[bool]] +CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str | None, str], Awaitable[bool]] CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] @@ -77,7 +77,7 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: "check_visibility_can_be_modified", } - def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + def async_wrapper(f: Callable | None) -> Callable[..., Awaitable] | None: # f might be None if the callback isn't implemented by the module. In this # case we don't want to register a callback at all so we return None. if f is None: @@ -93,7 +93,7 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: async def wrap_check_event_allowed( event: EventBase, state_events: StateMap[EventBase], - ) -> tuple[bool, Optional[dict]]: + ) -> tuple[bool, dict | None]: # Assertion required because mypy can't prove we won't change # `f` back to `None`. See # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions @@ -188,28 +188,23 @@ class ThirdPartyEventRulesModuleApiCallbacks: def register_third_party_rules_callbacks( self, - check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, - on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, - check_threepid_can_be_invited: Optional[ - CHECK_THREEPID_CAN_BE_INVITED_CALLBACK - ] = None, - check_visibility_can_be_modified: Optional[ - CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK - ] = None, - on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, - check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None, - check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None, - on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, - on_user_deactivation_status_changed: Optional[ - ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK - ] = None, - on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, - on_add_user_third_party_identifier: Optional[ - ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = None, - on_remove_user_third_party_identifier: Optional[ - ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = None, + check_event_allowed: CHECK_EVENT_ALLOWED_CALLBACK | None = None, + on_create_room: ON_CREATE_ROOM_CALLBACK | None = None, + check_threepid_can_be_invited: CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + | None = None, + check_visibility_can_be_modified: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + | None = None, + on_new_event: ON_NEW_EVENT_CALLBACK | None = None, + check_can_shutdown_room: CHECK_CAN_SHUTDOWN_ROOM_CALLBACK | None = None, + check_can_deactivate_user: CHECK_CAN_DEACTIVATE_USER_CALLBACK | None = None, + on_profile_update: ON_PROFILE_UPDATE_CALLBACK | None = None, + on_user_deactivation_status_changed: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK + | None = None, + on_threepid_bind: ON_THREEPID_BIND_CALLBACK | None = None, + on_add_user_third_party_identifier: ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + | None = None, + on_remove_user_third_party_identifier: ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + | None = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -261,7 +256,7 @@ class ThirdPartyEventRulesModuleApiCallbacks: self, event: EventBase, context: UnpersistedEventContextBase, - ) -> tuple[bool, Optional[dict]]: + ) -> tuple[bool, dict | None]: """Check if a provided event should be allowed in the given context. The module can return: @@ -443,9 +438,7 @@ class ThirdPartyEventRulesModuleApiCallbacks: "Failed to run module API callback %s: %s", callback, e ) - async def check_can_shutdown_room( - self, user_id: Optional[str], room_id: str - ) -> bool: + async def check_can_shutdown_room(self, user_id: str | None, room_id: str) -> bool: """Intercept requests to shutdown a room. If `False` is returned, the room must not be shut down. diff --git a/synapse/notifier.py b/synapse/notifier.py index 4a75d07e37..260a2c0d87 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -28,9 +28,7 @@ from typing import ( Iterable, Literal, Mapping, - Optional, TypeVar, - Union, overload, ) @@ -211,7 +209,7 @@ class _NotifierUserStream: @attr.s(slots=True, frozen=True, auto_attribs=True) class EventStreamResult: - events: list[Union[JsonDict, EventBase]] + events: list[JsonDict | EventBase] start_token: StreamToken end_token: StreamToken @@ -226,8 +224,8 @@ class _PendingRoomEventEntry: room_id: str type: str - state_key: Optional[str] - membership: Optional[str] + state_key: str | None + membership: str | None class Notifier: @@ -336,7 +334,7 @@ class Notifier: self, events_and_pos: list[tuple[EventBase, PersistedEventPosition]], max_room_stream_token: RoomStreamToken, - extra_users: Optional[Collection[UserID]] = None, + extra_users: Collection[UserID] | None = None, ) -> None: """Creates a _PendingRoomEventEntry for each of the listed events and calls notify_new_room_events with the results.""" @@ -421,11 +419,11 @@ class Notifier: def create_pending_room_event_entry( self, event_pos: PersistedEventPosition, - extra_users: Optional[Collection[UserID]], + extra_users: Collection[UserID] | None, room_id: str, event_type: str, - state_key: Optional[str], - membership: Optional[str], + state_key: str | None, + membership: str | None, ) -> _PendingRoomEventEntry: """Creates and returns a _PendingRoomEventEntry""" return _PendingRoomEventEntry( @@ -504,8 +502,8 @@ class Notifier: self, stream_key: Literal[StreamKeyType.ROOM], new_token: RoomStreamToken, - users: Optional[Collection[Union[str, UserID]]] = None, - rooms: Optional[StrCollection] = None, + users: Collection[str | UserID] | None = None, + rooms: StrCollection | None = None, ) -> None: ... @overload @@ -513,8 +511,8 @@ class Notifier: self, stream_key: Literal[StreamKeyType.RECEIPT], new_token: MultiWriterStreamToken, - users: Optional[Collection[Union[str, UserID]]] = None, - rooms: Optional[StrCollection] = None, + users: Collection[str | UserID] | None = None, + rooms: StrCollection | None = None, ) -> None: ... @overload @@ -531,16 +529,16 @@ class Notifier: StreamKeyType.THREAD_SUBSCRIPTIONS, ], new_token: int, - users: Optional[Collection[Union[str, UserID]]] = None, - rooms: Optional[StrCollection] = None, + users: Collection[str | UserID] | None = None, + rooms: StrCollection | None = None, ) -> None: ... def on_new_event( self, stream_key: StreamKeyType, - new_token: Union[int, RoomStreamToken, MultiWriterStreamToken], - users: Optional[Collection[Union[str, UserID]]] = None, - rooms: Optional[StrCollection] = None, + new_token: int | RoomStreamToken | MultiWriterStreamToken, + users: Collection[str | UserID] | None = None, + rooms: StrCollection | None = None, ) -> None: """Used to inform listeners that something has happened event wise. @@ -636,7 +634,7 @@ class Notifier: user_id: str, timeout: int, callback: Callable[[StreamToken, StreamToken], Awaitable[T]], - room_ids: Optional[StrCollection] = None, + room_ids: StrCollection | None = None, from_token: StreamToken = StreamToken.START, ) -> T: """Wait until the callback returns a non empty response or the @@ -737,7 +735,7 @@ class Notifier: pagination_config: PaginationConfig, timeout: int, is_guest: bool = False, - explicit_room_id: Optional[str] = None, + explicit_room_id: str | None = None, ) -> EventStreamResult: """For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any @@ -767,7 +765,7 @@ class Notifier: # The events fetched from each source are a JsonDict, EventBase, or # UserPresenceState, but see below for UserPresenceState being # converted to JsonDict. - events: list[Union[JsonDict, EventBase]] = [] + events: list[JsonDict | EventBase] = [] end_token = from_token for keyname, source in self.event_sources.sources.get_sources(): @@ -866,7 +864,7 @@ class Notifier: await self.clock.sleep(0.5) async def _get_room_ids( - self, user: UserID, explicit_room_id: Optional[str] + self, user: UserID, explicit_room_id: str | None ) -> tuple[StrCollection, bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 552af8e14a..58c58fbdbf 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -94,7 +94,7 @@ The Pusher instance also calls out to various utilities for generating payloads """ import abc -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import attr @@ -108,7 +108,7 @@ if TYPE_CHECKING: class PusherConfig: """Parameters necessary to configure a pusher.""" - id: Optional[int] + id: int | None user_name: str profile_tag: str @@ -118,18 +118,18 @@ class PusherConfig: device_display_name: str pushkey: str ts: int - lang: Optional[str] - data: Optional[JsonDict] + lang: str | None + data: JsonDict | None last_stream_ordering: int - last_success: Optional[int] - failing_since: Optional[int] + last_success: int | None + failing_since: int | None enabled: bool - device_id: Optional[str] + device_id: str | None # XXX(quenting): The access_token is not persisted anymore for new pushers, but we # keep it when reading from the database, so that we don't get stale pushers # while the "set_device_id_for_pushers" background update is running. - access_token: Optional[int] + access_token: int | None def as_dict(self) -> dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9fcd7fdc6e..7cf89200a8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -25,9 +25,7 @@ from typing import ( Any, Collection, Mapping, - Optional, Sequence, - Union, cast, ) @@ -233,7 +231,7 @@ class BulkPushRuleEvaluator: event: EventBase, context: EventContext, event_id_to_event: Mapping[str, EventBase], - ) -> tuple[dict, Optional[int]]: + ) -> tuple[dict, int | None]: """ Given an event and an event context, get the power level event relevant to the event and the power level of the sender of the event. @@ -390,7 +388,7 @@ class BulkPushRuleEvaluator: count_as_unread = _should_count_as_unread(event, context) rules_by_user = await self._get_rules_for_event(event) - actions_by_user: dict[str, Collection[Union[Mapping, str]]] = {} + actions_by_user: dict[str, Collection[Mapping | str]] = {} # Gather a bunch of info in parallel. # @@ -405,7 +403,7 @@ class BulkPushRuleEvaluator: profiles, ) = await make_deferred_yieldable( cast( - "Deferred[tuple[int, tuple[dict, Optional[int]], dict[str, dict[str, JsonValue]], Mapping[str, ProfileInfo]]]", + "Deferred[tuple[int, tuple[dict, int | None], dict[str, dict[str, JsonValue]], Mapping[str, ProfileInfo]]]", gather_results( ( run_in_background( # type: ignore[call-overload] @@ -477,7 +475,7 @@ class BulkPushRuleEvaluator: self.hs.config.experimental.msc4306_enabled, ) - msc4306_thread_subscribers: Optional[frozenset[str]] = None + msc4306_thread_subscribers: frozenset[str] | None = None if self.hs.config.experimental.msc4306_enabled and thread_id != MAIN_TIMELINE: # pull out, in batch, all local subscribers to this thread # (in the common case, they will all be getting processed for push @@ -510,7 +508,7 @@ class BulkPushRuleEvaluator: # current user, it'll be added to the dict later. actions_by_user[uid] = [] - msc4306_thread_subscription_state: Optional[bool] = None + msc4306_thread_subscription_state: bool | None = None if msc4306_thread_subscribers is not None: msc4306_thread_subscription_state = uid in msc4306_thread_subscribers @@ -552,10 +550,10 @@ class BulkPushRuleEvaluator: ) -MemberMap = dict[str, Optional[EventIdMembership]] +MemberMap = dict[str, EventIdMembership | None] Rule = dict[str, dict] RulesByUser = dict[str, list[Rule]] -StateGroup = Union[object, int] +StateGroup = object | int def _is_simple_value(value: Any) -> bool: @@ -567,9 +565,9 @@ def _is_simple_value(value: Any) -> bool: def _flatten_dict( - d: Union[EventBase, Mapping[str, Any]], - prefix: Optional[list[str]] = None, - result: Optional[dict[str, JsonValue]] = None, + d: EventBase | Mapping[str, Any], + prefix: list[str] | None = None, + result: dict[str, JsonValue] | None = None, ) -> dict[str, JsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index fd1758db9d..db082e295d 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -20,7 +20,7 @@ # import copy -from typing import Any, Optional +from typing import Any from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP from synapse.synapse_rust.push import FilteredPushRules, PushRule @@ -85,7 +85,7 @@ def _add_empty_priority_class_arrays(d: dict[str, list]) -> dict[str, list]: return d -def _rule_to_template(rule: PushRule) -> Optional[dict[str, Any]]: +def _rule_to_template(rule: PushRule) -> dict[str, Any] | None: templaterule: dict[str, Any] unscoped_rule_id = _rule_id_from_namespaced(rule.rule_id) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 83823c2284..36dc9bf6fc 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.interfaces import IDelayedCall @@ -70,7 +70,7 @@ class EmailPusher(Pusher): self.server_name = hs.hostname self.store = self.hs.get_datastores().main self.email = pusher_config.pushkey - self.timed_call: Optional[IDelayedCall] = None + self.timed_call: IDelayedCall | None = None self.throttle_params: dict[str, ThrottleParams] = {} self._inited = False @@ -174,7 +174,7 @@ class EmailPusher(Pusher): ) ) - soonest_due_at: Optional[int] = None + soonest_due_at: int | None = None if not unprocessed: await self.save_last_stream_ordering_and_success(self.max_stream_ordering) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 8df106b859..edcabf0c29 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -21,7 +21,7 @@ import logging import random import urllib.parse -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from prometheus_client import Counter @@ -68,7 +68,7 @@ http_badges_failed_counter = Counter( ) -def tweaks_for_actions(actions: list[Union[str, dict]]) -> JsonMapping: +def tweaks_for_actions(actions: list[str | dict]) -> JsonMapping: """ Converts a list of actions into a `tweaks` dict (which can then be passed to the push gateway). @@ -119,7 +119,7 @@ class HttpPusher(Pusher): self.data = pusher_config.data self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.failing_since = pusher_config.failing_since - self.timed_call: Optional[IDelayedCall] = None + self.timed_call: IDelayedCall | None = None self._is_processing = False self._group_unread_count_by_room = ( hs.config.push.push_group_unread_count_by_room @@ -163,7 +163,7 @@ class HttpPusher(Pusher): self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] - self.badge_count_last_call: Optional[int] = None + self.badge_count_last_call: int | None = None def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. @@ -394,9 +394,9 @@ class HttpPusher(Pusher): async def dispatch_push( self, content: JsonDict, - tweaks: Optional[JsonMapping] = None, - default_payload: Optional[JsonMapping] = None, - ) -> Union[bool, list[str]]: + tweaks: JsonMapping | None = None, + default_payload: JsonMapping | None = None, + ) -> bool | list[str]: """Send a notification to the registered push gateway, with `content` being the content of the `notification` top property specified in the spec. Note that the `devices` property will be added with device-specific @@ -453,7 +453,7 @@ class HttpPusher(Pusher): event: EventBase, tweaks: JsonMapping, badge: int, - ) -> Union[bool, list[str]]: + ) -> bool | list[str]: """Send a notification to the registered push gateway by building it from an event. diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3dac61aed5..6492207403 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -21,7 +21,7 @@ import logging import urllib.parse -from typing import TYPE_CHECKING, Iterable, Optional, TypeVar +from typing import TYPE_CHECKING, Iterable, TypeVar import bleach import jinja2 @@ -372,7 +372,7 @@ class Mailer: email_address: str, subject: str, extra_template_vars: TemplateVars, - unsubscribe_link: Optional[str] = None, + unsubscribe_link: str | None = None, ) -> None: """Send an email with the given information and template text""" template_vars: TemplateVars = { @@ -486,7 +486,7 @@ class Mailer: async def _get_room_avatar( self, room_state_ids: StateMap[str], - ) -> Optional[str]: + ) -> str | None: """ Retrieve the avatar url for this room---if it exists. @@ -553,7 +553,7 @@ class Mailer: async def _get_message_vars( self, notif: EmailPushAction, event: EventBase, room_state_ids: StateMap[str] - ) -> Optional[MessageVars]: + ) -> MessageVars | None: """ Generate the variables for a single event, if possible. @@ -573,7 +573,7 @@ class Mailer: type_state_key = ("m.room.member", event.sender) sender_state_event_id = room_state_ids.get(type_state_key) if sender_state_event_id: - sender_state_event: Optional[EventBase] = await self.store.get_event( + sender_state_event: EventBase | None = await self.store.get_event( sender_state_event_id ) else: @@ -585,9 +585,7 @@ class Mailer: if sender_state_event: sender_name = name_from_member_event(sender_state_event) - sender_avatar_url: Optional[str] = sender_state_event.content.get( - "avatar_url" - ) + sender_avatar_url: str | None = sender_state_event.content.get("avatar_url") else: # No state could be found, fallback to the MXID. sender_name = event.sender diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 2f32e18b9a..d8000dd607 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -21,7 +21,7 @@ import logging import re -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -45,7 +45,7 @@ async def calculate_room_name( user_id: str, fallback_to_members: bool = True, fallback_to_single_member: bool = True, -) -> Optional[str]: +) -> str | None: """ Works out a user-facing name for the given room as per Matrix spec recommendations. diff --git a/synapse/push/push_types.py b/synapse/push/push_types.py index e1678cd717..7553b4bf10 100644 --- a/synapse/push/push_types.py +++ b/synapse/push/push_types.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional, TypedDict +from typing import TypedDict class EmailReason(TypedDict, total=False): @@ -40,7 +40,7 @@ class EmailReason(TypedDict, total=False): room_id: str now: int - room_name: Optional[str] + room_name: str | None received_at: int delay_before_mail_ms: int last_sent_ts: int @@ -71,9 +71,9 @@ class MessageVars(TypedDict, total=False): id: str ts: int sender_name: str - sender_avatar_url: Optional[str] + sender_avatar_url: str | None sender_hash: int - msgtype: Optional[str] + msgtype: str | None body_text_html: str body_text_plain: str image_url: str @@ -90,7 +90,7 @@ class NotifVars(TypedDict): """ link: str - ts: Optional[int] + ts: int | None messages: list[MessageVars] @@ -107,12 +107,12 @@ class RoomVars(TypedDict): avator_url: url to the room's avator """ - title: Optional[str] + title: str | None hash: int invite: bool notifs: list[NotifVars] link: str - avatar_url: Optional[str] + avatar_url: str | None class TemplateVars(TypedDict, total=False): diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 17238c95c0..948465cad1 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable from synapse.push import Pusher, PusherConfig from synapse.push.emailpusher import EmailPusher @@ -53,7 +53,7 @@ class PusherFactory: logger.info("defined email pusher type") - def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: + def create_pusher(self, pusher_config: PusherConfig) -> Pusher | None: kind = pusher_config.kind f = self.pusher_types.get(kind, None) if not f: diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 6b70de976a..7b5b06db83 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable from prometheus_client import Gauge @@ -119,12 +119,12 @@ class PusherPool: app_display_name: str, device_display_name: str, pushkey: str, - lang: Optional[str], + lang: str | None, data: JsonDict, profile_tag: str = "", enabled: bool = True, - device_id: Optional[str] = None, - ) -> Optional[Pusher]: + device_id: str | None = None, + ) -> Pusher | None: """Creates a new pusher and adds it to the pool Returns: @@ -330,7 +330,7 @@ class PusherPool: async def _get_pusher_config_for_user_by_app_id_and_pushkey( self, user_id: str, app_id: str, pushkey: str - ) -> Optional[PusherConfig]: + ) -> PusherConfig | None: resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_config = None @@ -342,7 +342,7 @@ class PusherPool: async def process_pusher_change_by_id( self, app_id: str, pushkey: str, user_id: str - ) -> Optional[Pusher]: + ) -> Pusher | None: """Look up the details for the given pusher, and either start it if its "enabled" flag is True, or try to stop it otherwise. @@ -381,7 +381,7 @@ class PusherPool: logger.info("Started pushers") - async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: + async def _start_pusher(self, pusher_config: PusherConfig) -> Pusher | None: """Start the given pusher Args: diff --git a/synapse/replication/http/delayed_events.py b/synapse/replication/http/delayed_events.py index e448ac32bf..26eaf68dae 100644 --- a/synapse/replication/http/delayed_events.py +++ b/synapse/replication/http/delayed_events.py @@ -13,7 +13,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.web.server import Request @@ -52,7 +52,7 @@ class ReplicationAddedDelayedEventRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, content: JsonDict - ) -> tuple[int, dict[str, Optional[JsonMapping]]]: + ) -> tuple[int, dict[str, JsonMapping | None]]: self.handler.on_added(int(content["next_send_ts"])) return 200, {} diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 2fadee8a06..7a11537f9e 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -19,7 +19,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.web.server import Request @@ -170,7 +170,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, content: JsonDict - ) -> tuple[int, dict[str, Optional[JsonMapping]]]: + ) -> tuple[int, dict[str, JsonMapping | None]]: user_ids: list[str] = content["user_ids"] logger.info("Resync for %r", user_ids) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 0022e12eac..fc21c20ca2 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -19,7 +19,7 @@ # import logging -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast from twisted.web.server import Request @@ -50,13 +50,13 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, - device_id: Optional[str], - initial_display_name: Optional[str], + device_id: str | None, + initial_display_name: str | None, is_guest: bool, is_appservice_ghost: bool, should_issue_refresh_token: bool, - auth_provider_id: Optional[str], - auth_provider_session_id: Optional[str], + auth_provider_id: str | None, + auth_provider_session_id: str | None, ) -> JsonDict: """ Args: diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 0e588037b6..8a6c971720 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -18,7 +18,7 @@ # # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.web.server import Request @@ -192,7 +192,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload( # type: ignore[override] invite_event_id: str, - txn_id: Optional[str], + txn_id: str | None, requester: Requester, content: JsonDict, ) -> JsonDict: @@ -260,7 +260,7 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload( # type: ignore[override] knock_event_id: str, - txn_id: Optional[str], + txn_id: str | None, requester: Requester, content: JsonDict, ) -> JsonDict: diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index 4a894b0221..960f0485ff 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.web.server import Request @@ -58,7 +58,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - async def _serialize_payload(user_id: str, device_id: Optional[str]) -> JsonDict: # type: ignore[override] + async def _serialize_payload(user_id: str, device_id: str | None) -> JsonDict: # type: ignore[override] return {"device_id": device_id} async def _handle_request( # type: ignore[override] @@ -102,7 +102,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint): @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, - device_id: Optional[str], + device_id: str | None, state: JsonDict, force_notify: bool = False, is_sync: bool = False, diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 780fcc463a..bd83b38c96 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -19,7 +19,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.web.server import Request @@ -59,14 +59,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, - password_hash: Optional[str], + password_hash: str | None, was_guest: bool, make_guest: bool, - appservice_id: Optional[str], - create_profile_with_displayname: Optional[str], + appservice_id: str | None, + create_profile_with_displayname: str | None, admin: bool, - user_type: Optional[str], - address: Optional[str], + user_type: str | None, + address: str | None, shadow_banned: bool, approved: bool, ) -> JsonDict: @@ -143,7 +143,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload( # type: ignore[override] - user_id: str, auth_result: JsonDict, access_token: Optional[str] + user_id: str, auth_result: JsonDict, access_token: str | None ) -> JsonDict: """ Args: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index f9605407af..297feb0049 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -21,7 +21,7 @@ """A replication client for use by synapse workers.""" import logging -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable from sortedcontainers import SortedList @@ -89,7 +89,7 @@ class ReplicationDataHandler: self._pusher_pool = hs.get_pusherpool() self._presence_handler = hs.get_presence_handler() - self.send_handler: Optional[FederationSenderHandler] = None + self.send_handler: FederationSenderHandler | None = None if hs.should_send_federation(): self.send_handler = FederationSenderHandler(hs) @@ -435,7 +435,7 @@ class FederationSenderHandler: # Stores the latest position in the federation stream we've gotten up # to. This is always set before we use it. - self.federation_position: Optional[int] = None + self.federation_position: int | None = None self._fed_position_linearizer = Linearizer( name="_fed_position_linearizer", clock=hs.get_clock() diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f115cc4db9..79194f7275 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -26,7 +26,7 @@ allowed to be sent by which side. import abc import logging -from typing import Optional, TypeVar +from typing import TypeVar from synapse.replication.tcp.streams._base import StreamRow from synapse.util.json import json_decoder, json_encoder @@ -137,7 +137,7 @@ class RdataCommand(Command): NAME = "RDATA" def __init__( - self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow + self, stream_name: str, instance_name: str, token: int | None, row: StreamRow ): self.stream_name = stream_name self.instance_name = instance_name @@ -288,7 +288,7 @@ class UserSyncCommand(Command): self, instance_id: str, user_id: str, - device_id: Optional[str], + device_id: str | None, is_syncing: bool, last_sync_ms: int, ): @@ -300,7 +300,7 @@ class UserSyncCommand(Command): @classmethod def from_line(cls: type["UserSyncCommand"], line: str) -> "UserSyncCommand": - device_id: Optional[str] + device_id: str | None instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4) if device_id == "None": @@ -407,7 +407,7 @@ class UserIpCommand(Command): access_token: str, ip: str, user_agent: str, - device_id: Optional[str], + device_id: str | None, last_seen: int, ): self.user_id = user_id diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index bcdd55d2e6..ca959a7aae 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from prometheus_client import Counter, Histogram @@ -73,7 +73,7 @@ class ExternalCache: self.server_name = hs.hostname if hs.config.redis.redis_enabled: - self._redis_connection: Optional["ConnectionHandler"] = ( + self._redis_connection: "ConnectionHandler" | None = ( hs.get_outbound_redis_connection() ) else: @@ -121,7 +121,7 @@ class ExternalCache: ) ) - async def get(self, cache_name: str, key: str) -> Optional[Any]: + async def get(self, cache_name: str, key: str) -> Any | None: """Look up a key/value in the named cache.""" if self._redis_connection is None: diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 8cf7f4b805..05370045e6 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -26,9 +26,7 @@ from typing import ( Awaitable, Iterable, Iterator, - Optional, TypeVar, - Union, ) from prometheus_client import Counter @@ -115,9 +113,7 @@ tcp_command_queue_gauge = LaterGauge( # the type of the entries in _command_queues_by_stream -_StreamCommandQueueItem = tuple[ - Union[RdataCommand, PositionCommand], IReplicationConnection -] +_StreamCommandQueueItem = tuple[RdataCommand | PositionCommand, IReplicationConnection] class ReplicationCommandHandler: @@ -245,7 +241,7 @@ class ReplicationCommandHandler: self._pending_batches: dict[str, list[Any]] = {} # The factory used to create connections. - self._factory: Optional[ReconnectingClientFactory] = None + self._factory: ReconnectingClientFactory | None = None # The currently connected connections. (The list of places we need to send # outgoing replication commands to.) @@ -341,7 +337,7 @@ class ReplicationCommandHandler: self._channels_to_subscribe_to.append(channel_name) def _add_command_to_stream_queue( - self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand] + self, conn: IReplicationConnection, cmd: RdataCommand | PositionCommand ) -> None: """Queue the given received command for processing @@ -368,7 +364,7 @@ class ReplicationCommandHandler: async def _process_command( self, - cmd: Union[PositionCommand, RdataCommand], + cmd: PositionCommand | RdataCommand, conn: IReplicationConnection, stream_name: str, ) -> None: @@ -459,7 +455,7 @@ class ReplicationCommandHandler: def on_USER_SYNC( self, conn: IReplicationConnection, cmd: UserSyncCommand - ) -> Optional[Awaitable[None]]: + ) -> Awaitable[None] | None: user_sync_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc() if self._is_presence_writer: @@ -475,7 +471,7 @@ class ReplicationCommandHandler: def on_CLEAR_USER_SYNC( self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand - ) -> Optional[Awaitable[None]]: + ) -> Awaitable[None] | None: if self._is_presence_writer: return self._presence_handler.update_external_syncs_clear(cmd.instance_id) else: @@ -491,7 +487,7 @@ class ReplicationCommandHandler: def on_USER_IP( self, conn: IReplicationConnection, cmd: UserIpCommand - ) -> Optional[Awaitable[None]]: + ) -> Awaitable[None] | None: user_ip_cache_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc() if self._is_master or self._should_insert_client_ips: @@ -833,7 +829,7 @@ class ReplicationCommandHandler: self, instance_id: str, user_id: str, - device_id: Optional[str], + device_id: str | None, is_syncing: bool, last_sync_ms: int, ) -> None: @@ -848,7 +844,7 @@ class ReplicationCommandHandler: access_token: str, ip: str, user_agent: str, - device_id: Optional[str], + device_id: str | None, last_seen: int, ) -> None: """Tell the master that the user made a request.""" @@ -858,7 +854,7 @@ class ReplicationCommandHandler: def send_remote_server_up(self, server: str) -> None: self.send_command(RemoteServerUpCommand(server)) - def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None: + def stream_update(self, stream_name: str, token: int | None, data: Any) -> None: """Called when a new update is available to stream to Redis subscribers. We need to check if the client is interested in the stream or not diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 733643cb64..3068e60af0 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -28,7 +28,7 @@ import fcntl import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, Any, Collection, Optional +from typing import TYPE_CHECKING, Any, Collection from prometheus_client import Counter from zope.interface import Interface, implementer @@ -153,7 +153,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 # When we requested the connection be closed - self.time_we_closed: Optional[int] = None + self.time_we_closed: int | None = None self.received_ping = False # Have we received a ping from the other side @@ -166,7 +166,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.pending_commands: list[Command] = [] # The LoopingCall for sending pings. - self._send_ping_loop: Optional[task.LoopingCall] = None + self._send_ping_loop: task.LoopingCall | None = None # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 4448117d62..27d43e6fba 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -21,7 +21,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast import attr from txredisapi import ( @@ -72,10 +72,10 @@ class ConstantProperty(Generic[T, V]): constant: V = attr.ib() - def __get__(self, obj: Optional[T], objtype: Optional[type[T]] = None) -> V: + def __get__(self, obj: T | None, objtype: type[T] | None = None) -> V: return self.constant - def __set__(self, obj: Optional[T], value: V) -> None: + def __set__(self, obj: T | None, value: V) -> None: pass @@ -119,7 +119,7 @@ class RedisSubscriber(SubscriberProtocol): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. - self._logging_context: Optional[BackgroundProcessLoggingContext] = None + self._logging_context: BackgroundProcessLoggingContext | None = None def _get_logging_context(self) -> BackgroundProcessLoggingContext: """ @@ -293,14 +293,14 @@ class SynapseRedisFactory(RedisFactory): self, hs: "HomeServer", uuid: str, - dbid: Optional[int], + dbid: int | None, poolsize: int, isLazy: bool = False, handler: type = ConnectionHandler, charset: str = "utf-8", - password: Optional[str] = None, + password: str | None = None, replyTimeout: int = 30, - convertNumbers: Optional[int] = True, + convertNumbers: int | None = True, ): super().__init__( uuid=uuid, @@ -422,9 +422,9 @@ def lazyConnection( hs: "HomeServer", host: str = "localhost", port: int = 6379, - dbid: Optional[int] = None, + dbid: int | None = None, reconnect: bool = True, - password: Optional[str] = None, + password: str | None = None, replyTimeout: int = 30, ) -> ConnectionHandler: """Creates a connection to Redis that is lazily set up and reconnects if the @@ -471,9 +471,9 @@ def lazyConnection( def lazyUnixConnection( hs: "HomeServer", path: str = "/tmp/redis.sock", - dbid: Optional[int] = None, + dbid: int | None = None, reconnect: bool = True, - password: Optional[str] = None, + password: str | None = None, replyTimeout: int = 30, ) -> ConnectionHandler: """Creates a connection to Redis that is lazily set up and reconnects if the diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 8df0a3853f..134d8d921f 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -22,7 +22,7 @@ import logging import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from prometheus_client import Counter @@ -321,7 +321,7 @@ class ReplicationStreamer: def _batch_updates( updates: list[tuple[Token, StreamRow]], -) -> list[tuple[Optional[Token], StreamRow]]: +) -> list[tuple[Token | None, StreamRow]]: """Takes a list of updates of form [(token, row)] and sets the token to None for all rows where the next row has the same token. This is used to implement batching. @@ -337,7 +337,7 @@ def _batch_updates( if not updates: return [] - new_updates: list[tuple[Optional[Token], StreamRow]] = [] + new_updates: list[tuple[Token | None, StreamRow]] = [] for i, update in enumerate(updates[:-1]): if update[0] == updates[i + 1][0]: new_updates.append((None, update[1])) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index d80bdb9b35..4fb2aac202 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -26,7 +26,6 @@ from typing import ( Any, Awaitable, Callable, - Optional, TypeVar, ) @@ -285,9 +284,9 @@ class BackfillStream(Stream): event_id: str room_id: str type: str - state_key: Optional[str] - redacts: Optional[str] - relates_to: Optional[str] + state_key: str | None + redacts: str | None + relates_to: str | None NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -435,7 +434,7 @@ class ReceiptsStream(_StreamFromIdGen): receipt_type: str user_id: str event_id: str - thread_id: Optional[str] + thread_id: str | None data: dict NAME = "receipts" @@ -510,7 +509,7 @@ class CachesStream(Stream): """ cache_func: str - keys: Optional[list[Any]] + keys: list[Any] | None invalidation_ts: int NAME = "caches" @@ -639,7 +638,7 @@ class AccountDataStream(_StreamFromIdGen): @attr.s(slots=True, frozen=True, auto_attribs=True) class AccountDataStreamRow: user_id: str - room_id: Optional[str] + room_id: str | None data_type: str NAME = "account_data" diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index a6314b0c7d..ca9f6f12da 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -20,7 +20,7 @@ # import heapq from collections import defaultdict -from typing import TYPE_CHECKING, Iterable, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Iterable, TypeVar, cast import attr @@ -93,7 +93,7 @@ class BaseEventsStreamRow: TypeId: str @classmethod - def from_data(cls: type[T], data: Iterable[Optional[str]]) -> T: + def from_data(cls: type[T], data: Iterable[str | None]) -> T: """Parse the data from the replication stream into a row. By default we just call the constructor with the data list as arguments @@ -111,10 +111,10 @@ class EventsStreamEventRow(BaseEventsStreamRow): event_id: str room_id: str type: str - state_key: Optional[str] - redacts: Optional[str] - relates_to: Optional[str] - membership: Optional[str] + state_key: str | None + redacts: str | None + relates_to: str | None + membership: str | None rejected: bool outlier: bool @@ -126,7 +126,7 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): room_id: str type: str state_key: str - event_id: Optional[str] + event_id: str | None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -282,6 +282,6 @@ class EventsStream(_StreamFromIdGen): @classmethod def parse_row(cls, row: StreamRow) -> "EventsStreamRow": - (typ, data) = cast(tuple[str, Iterable[Optional[str]]], row) + (typ, data) = cast(tuple[str, Iterable[str | None]], row) event_stream_row_data = TypeToRow[typ].from_data(data) return EventsStreamRow(typ, event_stream_row_data) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index ea0e47ded4..fe66494d82 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -19,7 +19,7 @@ # # import logging -from typing import TYPE_CHECKING, Callable, Iterable, Optional +from typing import TYPE_CHECKING, Callable, Iterable from synapse.http.server import HttpServer, JsonResource from synapse.rest import admin @@ -143,7 +143,7 @@ class ClientRestResource(JsonResource): * etc """ - def __init__(self, hs: "HomeServer", servlet_groups: Optional[list[str]] = None): + def __init__(self, hs: "HomeServer", servlet_groups: list[str] | None = None): JsonResource.__init__(self, hs, canonical_json=False) if hs.config.media.can_load_media_repo: # This import is here to prevent a circular import failure @@ -156,7 +156,7 @@ class ClientRestResource(JsonResource): def register_servlets( client_resource: HttpServer, hs: "HomeServer", - servlet_groups: Optional[Iterable[str]] = None, + servlet_groups: Iterable[str] | None = None, ) -> None: # Some servlets are only registered on the main process (and not worker # processes). diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index bcaba85da3..e34ebb17e6 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -35,7 +35,7 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.handlers.pagination import PURGE_HISTORY_ACTION_NAME @@ -153,7 +153,7 @@ class PurgeHistoryRestServlet(RestServlet): self.auth = hs.get_auth() async def on_POST( - self, request: SynapseRequest, room_id: str, event_id: Optional[str] + self, request: SynapseRequest, room_id: str, event_id: str | None ) -> tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -173,7 +173,7 @@ class PurgeHistoryRestServlet(RestServlet): if event.room_id != room_id: raise SynapseError(HTTPStatus.BAD_REQUEST, "Event is for wrong room.") - # RoomStreamToken expects [int] not Optional[int] + # RoomStreamToken expects [int] not [int | None] assert event.internal_metadata.stream_ordering is not None room_token = RoomStreamToken( topological=event.depth, stream=event.internal_metadata.stream_ordering diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index cfdb314b1a..d5346fe0d5 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -20,7 +20,7 @@ # import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import attr @@ -374,7 +374,7 @@ class DeleteMediaByDateSize(RestServlet): self.media_repository = hs.get_media_repository() async def on_POST( - self, request: SynapseRequest, server_name: Optional[str] = None + self, request: SynapseRequest, server_name: str | None = None ) -> tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index e1bfca3c03..cf24bc628a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -20,7 +20,7 @@ # import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast import attr from immutabledict import immutabledict @@ -565,7 +565,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): # Get the room ID from the identifier. try: - remote_room_hosts: Optional[list[str]] = [ + remote_room_hosts: list[str] | None = [ x.decode("ascii") for x in request.args[b"server_name"] ] except Exception: diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index 0be04c0f90..50d2f35b18 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -18,7 +18,7 @@ # # from http import HTTPStatus -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError, SynapseError @@ -80,7 +80,7 @@ class SendServerNoticeServlet(RestServlet): self, request: SynapseRequest, requester: Requester, - txn_id: Optional[str], + txn_id: str | None, ) -> tuple[int, JsonDict]: await assert_user_is_admin(self.auth, requester) body = parse_json_object_from_request(request) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 3eab53e5a2..42e9f8043d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -23,7 +23,7 @@ import hmac import logging import secrets from http import HTTPStatus -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import attr from pydantic import StrictBool, StrictInt, StrictStr @@ -163,7 +163,7 @@ class UsersRestServletV2(RestServlet): direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + # twisted.web.server.Request.args is incorrectly defined as Any | None args: dict[bytes, list[bytes]] = request.args # type: ignore not_user_types = parse_strings_from_args(args, "not_user_type") @@ -195,7 +195,7 @@ class UsersRestServletV2(RestServlet): return HTTPStatus.OK, ret - def _parse_parameter_deactivated(self, request: SynapseRequest) -> Optional[bool]: + def _parse_parameter_deactivated(self, request: SynapseRequest) -> bool | None: """ Return None (no filtering) if `deactivated` is `true`, otherwise return `False` (exclude deactivated users from the results). @@ -206,9 +206,7 @@ class UsersRestServletV2(RestServlet): class UsersRestServletV3(UsersRestServletV2): PATTERNS = admin_patterns("/users$", "v3") - def _parse_parameter_deactivated( - self, request: SynapseRequest - ) -> Union[bool, None]: + def _parse_parameter_deactivated(self, request: SynapseRequest) -> bool | None: return parse_boolean(request, "deactivated") @@ -340,7 +338,7 @@ class UserRestServletV2(RestServlet): HTTPStatus.BAD_REQUEST, "An user can't be deactivated and locked" ) - approved: Optional[bool] = None + approved: bool | None = None if "approved" in body and self._msc3866_enabled: approved = body["approved"] if not isinstance(approved, bool): @@ -920,7 +918,7 @@ class SearchUsersRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, target_user_id: str - ) -> tuple[int, Optional[list[JsonDict]]]: + ) -> tuple[int, list[JsonDict] | None]: """Get request to search user table for specific users according to search term. This needs user to have a administrator access in Synapse. @@ -1476,9 +1474,9 @@ class RedactUser(RestServlet): class PostBody(RequestBodyModel): rooms: list[StrictStr] - reason: Optional[StrictStr] = None - limit: Optional[StrictInt] = None - use_admin: Optional[StrictBool] = None + reason: StrictStr | None = None + limit: StrictInt | None = None + use_admin: StrictBool | None = None async def on_POST( self, request: SynapseRequest, user_id: str diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index f928a8a3f4..b052052be0 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -21,7 +21,7 @@ # import logging import random -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal from urllib.parse import urlparse import attr @@ -161,11 +161,11 @@ class PasswordRestServlet(RestServlet): self._set_password_handler = hs.get_set_password_handler() class PostBody(RequestBodyModel): - auth: Optional[AuthenticationData] = None + auth: AuthenticationData | None = None logout_devices: StrictBool = True - new_password: Optional[ - Annotated[str, StringConstraints(max_length=512, strict=True)] - ] = None + new_password: ( + Annotated[str, StringConstraints(max_length=512, strict=True)] | None + ) = None @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]: @@ -259,7 +259,7 @@ class PasswordRestServlet(RestServlet): # If we have a password in this request, prefer it. Otherwise, use the # password hash from an earlier request. if new_password: - password_hash: Optional[str] = await self.auth_handler.hash(new_password) + password_hash: str | None = await self.auth_handler.hash(new_password) elif session_id is not None: password_hash = existing_session_password_hash else: @@ -289,8 +289,8 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() class PostBody(RequestBodyModel): - auth: Optional[AuthenticationData] = None - id_server: Optional[StrictStr] = None + auth: AuthenticationData | None = None + id_server: StrictStr | None = None # Not specced, see https://github.com/matrix-org/matrix-spec/issues/297 erase: StrictBool = False @@ -663,7 +663,7 @@ class ThreepidAddRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): - auth: Optional[AuthenticationData] = None + auth: AuthenticationData | None = None client_secret: ClientSecretStr sid: StrictStr @@ -742,7 +742,7 @@ class ThreepidUnbindRestServlet(RestServlet): class PostBody(RequestBodyModel): address: StrictStr - id_server: Optional[StrictStr] = None + id_server: StrictStr | None = None medium: Literal["email", "msisdn"] async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]: @@ -771,7 +771,7 @@ class ThreepidDeleteRestServlet(RestServlet): class PostBody(RequestBodyModel): address: StrictStr - id_server: Optional[StrictStr] = None + id_server: StrictStr | None = None medium: Literal["email", "msisdn"] async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]: diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 0800c0f5b8..b18232fc56 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.constants import AccountDataTypes, ReceiptTypes from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError @@ -108,9 +108,9 @@ class AccountDataServlet(RestServlet): # Push rules are stored in a separate table and must be queried separately. if account_data_type == AccountDataTypes.PUSH_RULES: - account_data: Optional[ - JsonMapping - ] = await self._push_rules_handler.push_rules_for_user(requester.user) + account_data: ( + JsonMapping | None + ) = await self._push_rules_handler.push_rules_for_user(requester.user) else: account_data = await self.store.get_global_account_data_by_type_for_user( user_id, account_data_type @@ -244,7 +244,7 @@ class RoomAccountDataServlet(RestServlet): # Room-specific push rules are not currently supported. if account_data_type == AccountDataTypes.PUSH_RULES: - account_data: Optional[JsonMapping] = {} + account_data: JsonMapping | None = {} else: account_data = await self.store.get_account_data_for_room_and_type( user_id, room_id, account_data_type diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index e20e49d48b..636e4b6031 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -22,7 +22,7 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pydantic import ConfigDict, StrictStr @@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): - auth: Optional[AuthenticationData] = None + auth: AuthenticationData | None = None devices: list[StrictStr] @interactive_auth_handler @@ -173,7 +173,7 @@ class DeviceRestServlet(RestServlet): return 200, device class DeleteBody(RequestBodyModel): - auth: Optional[AuthenticationData] = None + auth: AuthenticationData | None = None @interactive_auth_handler async def on_DELETE( @@ -218,7 +218,7 @@ class DeviceRestServlet(RestServlet): return 200, {} class PutBody(RequestBodyModel): - display_name: Optional[StrictStr] = None + display_name: StrictStr | None = None async def on_PUT( self, request: SynapseRequest, device_id: str @@ -316,7 +316,7 @@ class DehydratedDeviceServlet(RestServlet): class PutBody(RequestBodyModel): device_data: DehydratedDeviceDataModel - initial_device_display_name: Optional[StrictStr] = None + initial_device_display_name: StrictStr | None = None async def on_PUT(self, request: SynapseRequest) -> tuple[int, JsonDict]: submission = parse_and_validate_json_object_from_request(request, self.PutBody) @@ -391,7 +391,7 @@ class DehydratedDeviceEventsServlet(RestServlet): self.store = hs.get_datastores().main class PostBody(RequestBodyModel): - next_batch: Optional[StrictStr] = None + next_batch: StrictStr | None = None async def on_POST( self, request: SynapseRequest, device_id: str @@ -538,7 +538,7 @@ class DehydratedDeviceV2Servlet(RestServlet): class PutBody(RequestBodyModel): device_data: DehydratedDeviceDataModel device_id: StrictStr - initial_device_display_name: Optional[StrictStr] + initial_device_display_name: StrictStr | None model_config = ConfigDict(extra="allow") async def on_PUT(self, request: SynapseRequest) -> tuple[int, JsonDict]: diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index 943674bbb1..0b334f9b0b 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal from pydantic import StrictStr @@ -73,7 +73,7 @@ class ClientDirectoryServer(RestServlet): # TODO: get Pydantic to validate that this is a valid room id? room_id: StrictStr # `servers` is unspecced - servers: Optional[list[StrictStr]] = None + servers: list[StrictStr] | None = None async def on_PUT( self, request: SynapseRequest, room_alias: str diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 082bacade6..de73c96fd0 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -22,7 +22,7 @@ """This module contains REST servlets to do with event streaming, /events.""" import logging -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from synapse.api.errors import SynapseError from synapse.events.utils import SerializeEventConfig @@ -96,7 +96,7 @@ class EventRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, event_id: str - ) -> tuple[int, Union[str, JsonDict]]: + ) -> tuple[int, str | JsonDict]: requester = await self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index b87b9bd68a..5f488674b4 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -24,7 +24,7 @@ import logging import re from collections import Counter from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Mapping from pydantic import StrictBool, StrictStr, field_validator @@ -147,7 +147,7 @@ class KeyUploadServlet(RestServlet): key: StrictStr """The key, encoded using unpadded base64.""" - fallback: Optional[StrictBool] = False + fallback: StrictBool | None = False """Whether this is a fallback key. Only used when handling fallback keys.""" signatures: Mapping[StrictStr, Mapping[StrictStr, StrictStr]] @@ -156,10 +156,10 @@ class KeyUploadServlet(RestServlet): See the following for more detail: https://spec.matrix.org/v1.16/appendices/#signing-details """ - device_keys: Optional[DeviceKeys] = None + device_keys: DeviceKeys | None = None """Identity keys for the device. May be absent if no new identity keys are required.""" - fallback_keys: Optional[Mapping[StrictStr, Union[StrictStr, KeyObject]]] = None + fallback_keys: Mapping[StrictStr, StrictStr | KeyObject] | None = None """ The public key which should be used if the device's one-time keys are exhausted. The fallback key is not deleted once used, but should be @@ -193,7 +193,7 @@ class KeyUploadServlet(RestServlet): ) return v - one_time_keys: Optional[Mapping[StrictStr, Union[StrictStr, KeyObject]]] = None + one_time_keys: Mapping[StrictStr, StrictStr | KeyObject] | None = None """ One-time public keys for "pre-key" messages. The names of the properties should be in the format `:`. @@ -221,7 +221,7 @@ class KeyUploadServlet(RestServlet): return v async def on_POST( - self, request: SynapseRequest, device_id: Optional[str] + self, request: SynapseRequest, device_id: str | None ) -> tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index 5e96079b66..cd3afda11e 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -69,7 +69,7 @@ class KnockRoomAliasServlet(RestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + # twisted.web.server.Request.args is incorrectly defined as Any | None args: dict[bytes, list[bytes]] = request.args # type: ignore # Prefer via over server_name (deprecated with MSC4156) remote_room_hosts = parse_strings_from_args(args, "via", required=False) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index bba6944982..fe3cb9aa3d 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -26,9 +26,7 @@ from typing import ( Any, Awaitable, Callable, - Optional, TypedDict, - Union, ) from synapse.api.constants import ApprovalNoticeMedium @@ -67,12 +65,12 @@ logger = logging.getLogger(__name__) class LoginResponse(TypedDict, total=False): user_id: str - access_token: Optional[str] + access_token: str | None home_server: str - expires_in_ms: Optional[int] - refresh_token: Optional[str] - device_id: Optional[str] - well_known: Optional[dict[str, Any]] + expires_in_ms: int | None + refresh_token: str | None + device_id: str | None + well_known: dict[str, Any] | None class LoginRestServlet(RestServlet): @@ -367,13 +365,13 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, + callback: Callable[[LoginResponse], Awaitable[None]] | None = None, create_non_existent_users: bool = False, - default_display_name: Optional[str] = None, + default_display_name: str | None = None, ratelimit: bool = True, - auth_provider_id: Optional[str] = None, + auth_provider_id: str | None = None, should_issue_refresh_token: bool = False, - auth_provider_session_id: Optional[str] = None, + auth_provider_session_id: str | None = None, should_check_deactivated_or_locked: bool = True, *, request_info: RequestInfo, @@ -623,7 +621,7 @@ class RefreshTokenServlet(RestServlet): token, access_valid_until_ms, refresh_valid_until_ms ) - response: dict[str, Union[str, int]] = { + response: dict[str, str | int] = { "access_token": access_token, "refresh_token": refresh_token, } @@ -652,9 +650,7 @@ class SsoRedirectServlet(RestServlet): self._sso_handler = hs.get_sso_handler() self._public_baseurl = hs.config.server.public_baseurl - async def on_GET( - self, request: SynapseRequest, idp_id: Optional[str] = None - ) -> None: + async def on_GET(self, request: SynapseRequest, idp_id: str | None = None) -> None: if not self._public_baseurl: raise SynapseError(400, "SSO requires a valid public_baseurl") diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 4c044ae900..f145b03af4 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -22,7 +22,6 @@ import logging import re -from typing import Optional from synapse.http.server import ( HttpServer, @@ -231,7 +230,7 @@ class DownloadResource(RestServlet): request: SynapseRequest, server_name: str, media_id: str, - file_name: Optional[str] = None, + file_name: str | None = None, ) -> None: # Validate the server name, raising if invalid parse_and_validate_server_name(server_name) diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py index 7d0570d0cb..bda6ed1f70 100644 --- a/synapse/rest/client/mutual_rooms.py +++ b/synapse/rest/client/mutual_rooms.py @@ -52,7 +52,7 @@ class UserMutualRoomsServlet(RestServlet): self.store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + # twisted.web.server.Request.args is incorrectly defined as Any | None args: dict[bytes, list[bytes]] = request.args # type: ignore user_ids = parse_strings_from_args(args, "user_id", required=True) diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 0a9b83af95..39b0cde47b 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -20,7 +20,7 @@ # from http import HTTPStatus -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from synapse.api.errors import ( Codes, @@ -240,7 +240,7 @@ def _rule_spec_from_path(path: list[str]) -> RuleSpec: def _rule_tuple_from_request_object( rule_template: str, rule_id: str, req_obj: JsonDict -) -> tuple[list[JsonDict], list[Union[str, JsonDict]]]: +) -> tuple[list[JsonDict], list[str | JsonDict]]: if rule_template == "postcontent": # postcontent is from MSC4306, which says that clients # cannot create their own postcontent rules right now. diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 145dc6f569..9503446b92 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -21,7 +21,7 @@ # import logging import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.web.server import Request @@ -852,7 +852,7 @@ class RegisterRestServlet(RestServlet): return result async def _do_guest_registration( - self, params: JsonDict, address: Optional[str] = None + self, params: JsonDict, address: str | None = None ) -> tuple[int, JsonDict]: if not self.hs.config.registration.allow_guest_access: raise SynapseError(403, "Guest access is disabled") diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d6c7411816..c913bc6970 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -20,7 +20,7 @@ import logging import re -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.constants import Direction from synapse.handlers.relations import ThreadsListInclude @@ -61,8 +61,8 @@ class RelationPaginationServlet(RestServlet): request: SynapseRequest, room_id: str, parent_id: str, - relation_type: Optional[str] = None, - event_type: Optional[str] = None, + relation_type: str | None = None, + event_type: str | None = None, ) -> tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py index a1808847f0..08a449eefc 100644 --- a/synapse/rest/client/rendezvous.py +++ b/synapse/rest/client/rendezvous.py @@ -21,7 +21,7 @@ import logging from http.client import TEMPORARY_REDIRECT -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.http.server import HttpServer, respond_with_redirect from synapse.http.servlet import RestServlet @@ -41,7 +41,7 @@ class MSC4108DelegationRendezvousServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - redirection_target: Optional[str] = ( + redirection_target: str | None = ( hs.config.experimental.msc4108_delegation_endpoint ) assert redirection_target is not None, ( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 38e315d0e7..81a6bd57fc 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -25,7 +25,7 @@ import logging import re from enum import Enum from http import HTTPStatus -from typing import TYPE_CHECKING, Awaitable, Optional +from typing import TYPE_CHECKING, Awaitable from urllib import parse as urlparse from prometheus_client.core import Histogram @@ -294,7 +294,7 @@ class RoomStateEventRestServlet(RestServlet): room_id: str, event_type: str, state_key: str, - txn_id: Optional[str] = None, + txn_id: str | None = None, ) -> tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -407,7 +407,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): requester: Requester, room_id: str, event_type: str, - txn_id: Optional[str], + txn_id: str | None, ) -> tuple[int, JsonDict]: content = parse_json_object_from_request(request) @@ -484,8 +484,8 @@ class RoomSendEventRestServlet(TransactionRestServlet): def _parse_request_delay( request: SynapseRequest, - max_delay: Optional[int], -) -> Optional[int]: + max_delay: int | None, +) -> int | None: """Parses from the request string the delay parameter for delayed event requests, and checks it for correctness. @@ -544,11 +544,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): request: SynapseRequest, requester: Requester, room_identifier: str, - txn_id: Optional[str], + txn_id: str | None, ) -> tuple[int, JsonDict]: content = parse_json_object_from_request(request, allow_empty_body=True) - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + # twisted.web.server.Request.args is incorrectly defined as Any | None args: dict[bytes, list[bytes]] = request.args # type: ignore # Prefer via over server_name (deprecated with MSC4156) remote_room_hosts = parse_strings_from_args(args, "via", required=False) @@ -623,7 +623,7 @@ class PublicRoomListRestServlet(RestServlet): if server: raise e - limit: Optional[int] = parse_integer(request, "limit", 0) + limit: int | None = parse_integer(request, "limit", 0) since_token = parse_string(request, "since") if limit == 0: @@ -658,7 +658,7 @@ class PublicRoomListRestServlet(RestServlet): server = parse_string(request, "server") content = parse_json_object_from_request(request) - limit: Optional[int] = int(content.get("limit", 100)) + limit: int | None = int(content.get("limit", 100)) since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -1118,7 +1118,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): requester: Requester, room_id: str, membership_action: str, - txn_id: Optional[str], + txn_id: str | None, ) -> tuple[int, JsonDict]: if requester.is_guest and membership_action not in { Membership.JOIN, @@ -1241,7 +1241,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester: Requester, room_id: str, event_id: str, - txn_id: Optional[str], + txn_id: str | None, ) -> tuple[int, JsonDict]: content = parse_json_object_from_request(request) @@ -1572,7 +1572,7 @@ class RoomHierarchyRestServlet(RestServlet): max_depth = parse_integer(request, "max_depth") limit = parse_integer(request, "limit") - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + # twisted.web.server.Request.args is incorrectly defined as Any | None remote_room_hosts = None if self.msc4235_enabled: args: dict[bytes, list[bytes]] = request.args # type: ignore @@ -1617,12 +1617,12 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): ) -> tuple[int, JsonDict]: try: requester = await self._auth.get_user_by_req(request, allow_guest=True) - requester_user_id: Optional[str] = requester.user.to_string() + requester_user_id: str | None = requester.user.to_string() except MissingClientTokenError: # auth is optional requester_user_id = None - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + # twisted.web.server.Request.args is incorrectly defined as Any | None args: dict[bytes, list[bytes]] = request.args # type: ignore remote_room_hosts = parse_strings_from_args(args, "via", required=False) room_id, remote_room_hosts = await self.resolve_room_id( diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py index b2de591dc5..b7f7c68d8f 100644 --- a/synapse/rest/client/room_keys.py +++ b/synapse/rest/client/room_keys.py @@ -19,7 +19,7 @@ # import logging -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer @@ -51,7 +51,7 @@ class RoomKeysServlet(RestServlet): self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() async def on_PUT( - self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + self, request: SynapseRequest, room_id: str | None, session_id: str | None ) -> tuple[int, JsonDict]: """ Uploads one or more encrypted E2E room keys for backup purposes. @@ -146,7 +146,7 @@ class RoomKeysServlet(RestServlet): return 200, ret async def on_GET( - self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + self, request: SynapseRequest, room_id: str | None, session_id: str | None ) -> tuple[int, JsonDict]: """ Retrieves one or more encrypted E2E room keys for backup purposes. @@ -233,7 +233,7 @@ class RoomKeysServlet(RestServlet): return 200, room_keys async def on_DELETE( - self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + self, request: SynapseRequest, room_id: str | None, session_id: str | None ) -> tuple[int, JsonDict]: """ Deletes one or more encrypted E2E room keys for a user for backup purposes. diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 9c03eecea4..458bf08a19 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -21,7 +21,7 @@ import itertools import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Mapping import attr @@ -189,7 +189,7 @@ class SyncRestServlet(RestServlet): # in the response cache once the set of ignored users has changed. # (We filter out ignored users from timeline events, so our sync response # is invalid once the set of ignored users changes.) - last_ignore_accdata_streampos: Optional[int] = None + last_ignore_accdata_streampos: int | None = None if not since: # No `since`, so this is an initial sync. last_ignore_accdata_streampos = await self.store.get_latest_stream_id_for_global_account_data_by_type_for_user( @@ -547,7 +547,7 @@ class SyncRestServlet(RestServlet): async def encode_room( self, sync_config: SyncConfig, - room: Union[JoinedSyncResult, ArchivedSyncResult], + room: JoinedSyncResult | ArchivedSyncResult, time_now: int, joined: bool, serialize_options: SerializeEventConfig, diff --git a/synapse/rest/client/thread_subscriptions.py b/synapse/rest/client/thread_subscriptions.py index d02f2cb48a..60676a4032 100644 --- a/synapse/rest/client/thread_subscriptions.py +++ b/synapse/rest/client/thread_subscriptions.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import attr from typing_extensions import TypeAlias @@ -50,7 +50,7 @@ class ThreadSubscriptionsRestServlet(RestServlet): self.handler = hs.get_thread_subscriptions_handler() class PutBody(RequestBodyModel): - automatic: Optional[AnyEventId] = None + automatic: AnyEventId | None = None """ If supplied, the event ID of an event giving rise to this automatic subscription. diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index f783acdb83..41e49ac384 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -21,7 +21,7 @@ import logging import re -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 @@ -107,7 +107,7 @@ class LocalKey(RestServlet): return json_object def on_GET( - self, request: Request, key_id: Optional[str] = None + self, request: Request, key_id: str | None = None ) -> tuple[int, JsonDict]: # Matrix 1.6 drops support for passing the key_id, this is incompatible # with earlier versions and is allowed in order to support both. diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index e8b0b31210..c3dc69889b 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -21,7 +21,7 @@ import logging import re -from typing import TYPE_CHECKING, Mapping, Optional +from typing import TYPE_CHECKING, Mapping from pydantic import ConfigDict, StrictInt, StrictStr from signedjson.sign import sign_json @@ -50,7 +50,7 @@ logger = logging.getLogger(__name__) class _KeyQueryCriteriaDataModel(RequestBodyModel): model_config = ConfigDict(extra="allow") - minimum_valid_until_ts: Optional[StrictInt] + minimum_valid_until_ts: StrictInt | None class RemoteKey(RestServlet): @@ -142,7 +142,7 @@ class RemoteKey(RestServlet): ) async def on_GET( - self, request: Request, server: str, key_id: Optional[str] = None + self, request: Request, server: str, key_id: str | None = None ) -> tuple[int, JsonDict]: if server and key_id: # Matrix 1.6 drops support for passing the key_id, this is incompatible @@ -181,11 +181,11 @@ class RemoteKey(RestServlet): ) -> JsonDict: logger.info("Handling query for keys %r", query) - server_keys: dict[tuple[str, str], Optional[FetchKeyResultForRemote]] = {} + server_keys: dict[tuple[str, str], FetchKeyResultForRemote | None] = {} for server_name, key_ids in query.items(): if key_ids: results: Mapping[ - str, Optional[FetchKeyResultForRemote] + str, FetchKeyResultForRemote | None ] = await self.store.get_server_keys_json_for_remote( server_name, key_ids ) diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py index 3c3f703667..f4569cfc7e 100644 --- a/synapse/rest/media/download_resource.py +++ b/synapse/rest/media/download_resource.py @@ -21,7 +21,7 @@ # import logging import re -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.http.server import set_corp_headers, set_cors_headers from synapse.http.servlet import RestServlet, parse_boolean, parse_integer @@ -57,7 +57,7 @@ class DownloadResource(RestServlet): request: SynapseRequest, server_name: str, media_id: str, - file_name: Optional[str] = None, + file_name: str | None = None, ) -> None: # Validate the server name, raising if invalid parse_and_validate_server_name(server_name) diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py index 484749dbe6..56bc727cf8 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -22,7 +22,7 @@ import logging import re -from typing import IO, TYPE_CHECKING, Optional +from typing import IO, TYPE_CHECKING from synapse.api.errors import Codes, SynapseError from synapse.http.server import respond_with_json @@ -56,7 +56,7 @@ class BaseUploadServlet(RestServlet): async def _get_file_metadata( self, request: SynapseRequest, user_id: str - ) -> tuple[int, Optional[str], str]: + ) -> tuple[int, str | None, str]: raw_content_length = request.getHeader("Content-Length") if raw_content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) @@ -82,7 +82,7 @@ class BaseUploadServlet(RestServlet): upload_name_bytes = parse_bytes_from_args(args, "filename") if upload_name_bytes: try: - upload_name: Optional[str] = upload_name_bytes.decode("utf8") + upload_name: str | None = upload_name_bytes.decode("utf8") except UnicodeDecodeError: raise SynapseError( msg="Invalid UTF-8 filename parameter: %r" % (upload_name_bytes,), diff --git a/synapse/rest/synapse/mas/devices.py b/synapse/rest/synapse/mas/devices.py index eac51de44c..9d94a67675 100644 --- a/synapse/rest/synapse/mas/devices.py +++ b/synapse/rest/synapse/mas/devices.py @@ -15,7 +15,7 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pydantic import StrictStr @@ -53,7 +53,7 @@ class MasUpsertDeviceResource(MasBaseResource): class PostBody(RequestBodyModel): localpart: StrictStr device_id: StrictStr - display_name: Optional[StrictStr] = None + display_name: StrictStr | None = None async def _async_render_POST( self, request: "SynapseRequest" diff --git a/synapse/rest/synapse/mas/users.py b/synapse/rest/synapse/mas/users.py index f52c4bb167..55c7337555 100644 --- a/synapse/rest/synapse/mas/users.py +++ b/synapse/rest/synapse/mas/users.py @@ -15,7 +15,7 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Optional, TypedDict +from typing import TYPE_CHECKING, Any, TypedDict from pydantic import StrictBool, StrictStr, model_validator @@ -52,8 +52,8 @@ class MasQueryUserResource(MasBaseResource): class Response(TypedDict): user_id: str - display_name: Optional[str] - avatar_url: Optional[str] + display_name: str | None + avatar_url: str | None is_suspended: bool is_deactivated: bool @@ -65,7 +65,7 @@ class MasQueryUserResource(MasBaseResource): localpart = parse_string(request, "localpart", required=True) user_id = UserID(localpart, self.hostname) - user: Optional[UserInfo] = await self.store.get_user_by_id(user_id=str(user_id)) + user: UserInfo | None = await self.store.get_user_by_id(user_id=str(user_id)) if user is None: raise NotFoundError("User not found") @@ -104,13 +104,13 @@ class MasProvisionUserResource(MasBaseResource): localpart: StrictStr unset_displayname: StrictBool = False - set_displayname: Optional[StrictStr] = None + set_displayname: StrictStr | None = None unset_avatar_url: StrictBool = False - set_avatar_url: Optional[StrictStr] = None + set_avatar_url: StrictStr | None = None unset_emails: StrictBool = False - set_emails: Optional[list[StrictStr]] = None + set_emails: list[StrictStr] | None = None @model_validator(mode="before") @classmethod @@ -165,7 +165,7 @@ class MasProvisionUserResource(MasBaseResource): by_admin=True, ) - new_email_list: Optional[set[str]] = None + new_email_list: set[str] | None = None if body.unset_emails: new_email_list = set() elif body.set_emails is not None: diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 00965cfb82..801d474ecc 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -18,7 +18,7 @@ # # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from twisted.web.resource import Resource from twisted.web.server import Request @@ -42,7 +42,7 @@ class WellKnownBuilder: self._config = hs.config self._auth = hs.get_auth() - async def get_well_known(self) -> Optional[JsonDict]: + async def get_well_known(self) -> JsonDict | None: if not self._config.server.serve_client_wellknown: return None diff --git a/synapse/server.py b/synapse/server.py index 766515c930..de0a2b098c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -34,7 +34,6 @@ from typing import ( Any, Awaitable, Callable, - Optional, TypeVar, cast, ) @@ -320,7 +319,7 @@ class HomeServer(metaclass=abc.ABCMeta): self, hostname: str, config: HomeServerConfig, - reactor: Optional[ISynapseReactor] = None, + reactor: ISynapseReactor | None = None, ): """ Args: @@ -340,33 +339,33 @@ class HomeServer(metaclass=abc.ABCMeta): self.config = config self._listening_services: list[Port] = [] self._metrics_listeners: list[tuple[WSGIServer, Thread]] = [] - self.start_time: Optional[int] = None + self.start_time: int | None = None self._instance_id = random_string(5) self._instance_name = config.worker.instance_name self.version_string = f"Synapse/{SYNAPSE_VERSION}" - self.datastores: Optional[Databases] = None + self.datastores: Databases | None = None self._module_web_resources: dict[str, Resource] = {} self._module_web_resources_consumed = False # This attribute is set by the free function `refresh_certificate`. - self.tls_server_context_factory: Optional[IOpenSSLContextFactory] = None + self.tls_server_context_factory: IOpenSSLContextFactory | None = None self._is_shutdown = False self._async_shutdown_handlers: list[ShutdownInfo] = [] self._sync_shutdown_handlers: list[ShutdownInfo] = [] - self._background_processes: set[defer.Deferred[Optional[Any]]] = set() + self._background_processes: set[defer.Deferred[Any | None]] = set() def run_as_background_process( self, desc: "LiteralString", - func: Callable[..., Awaitable[Optional[R]]], + func: Callable[..., Awaitable[R | None]], *args: Any, **kwargs: Any, - ) -> "defer.Deferred[Optional[R]]": + ) -> "defer.Deferred[R | None]": """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 73cf4091eb..b4e512618e 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -18,7 +18,7 @@ # # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.events import EventBase @@ -59,8 +59,8 @@ class ServerNoticesManager: user_id: str, event_content: dict, type: str = EventTypes.Message, - state_key: Optional[str] = None, - txn_id: Optional[str] = None, + state_key: str | None = None, + txn_id: str | None = None, ) -> EventBase: """Send a notice to the given user @@ -99,7 +99,7 @@ class ServerNoticesManager: return event @cached() - async def maybe_get_notice_room_for_user(self, user_id: str) -> Optional[str]: + async def maybe_get_notice_room_for_user(self, user_id: str) -> str | None: """Try to look up the server notice room for this user if it exists. Does not create one if none can be found. @@ -294,8 +294,8 @@ class ServerNoticesManager: self, requester: Requester, room_id: str, - display_name: Optional[str], - avatar_url: Optional[str], + display_name: str | None, + avatar_url: str | None, ) -> None: """ Updates the notice user's profile if it's different from what is in the room. @@ -341,7 +341,7 @@ class ServerNoticesManager: room_id: str, info_event_type: str, info_content_key: str, - info_value: Optional[str], + info_value: str | None, ) -> None: """ Updates a specific notice room's info if it's different from what is set. diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index bc62d6ac6c..fd4f36f5c8 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -17,7 +17,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Iterable, Union +from typing import TYPE_CHECKING, Iterable from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( @@ -39,7 +39,7 @@ class ServerNoticesSender(WorkerServerNoticesSender): def __init__(self, hs: "HomeServer"): super().__init__(hs) self._server_notices: Iterable[ - Union[ConsentServerNotices, ResourceLimitsServerNotices] + ConsentServerNotices | ResourceLimitsServerNotices ] = ( ConsentServerNotices(hs), ResourceLimitsServerNotices(hs), diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 991e1f847a..9fc49be4b1 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -96,10 +96,10 @@ class _StateCacheEntry: def __init__( self, - state: Optional[StateMap[str]], - state_group: Optional[int], - prev_group: Optional[int] = None, - delta_ids: Optional[StateMap[str]] = None, + state: StateMap[str] | None, + state_group: int | None, + prev_group: int | None = None, + delta_ids: StateMap[str] | None = None, ): if state is None and state_group is None and prev_group is None: raise Exception("One of state, state_group or prev_group must be not None") @@ -111,7 +111,7 @@ class _StateCacheEntry: # # This can be None if we have a `state_group` (as then we can fetch the # state from the DB.) - self._state: Optional[StateMap[str]] = ( + self._state: StateMap[str] | None = ( immutabledict(state) if state is not None else None ) @@ -120,7 +120,7 @@ class _StateCacheEntry: self.state_group = state_group self.prev_group = prev_group - self.delta_ids: Optional[StateMap[str]] = ( + self.delta_ids: StateMap[str] | None = ( immutabledict(delta_ids) if delta_ids is not None else None ) @@ -206,7 +206,7 @@ class StateHandler: self, room_id: str, event_ids: StrCollection, - state_filter: Optional[StateFilter] = None, + state_filter: StateFilter | None = None, await_full_state: bool = True, ) -> StateMap[str]: """Fetch the state after each of the given event IDs. Resolve them and return. @@ -283,9 +283,9 @@ class StateHandler: async def calculate_context_info( self, event: EventBase, - state_ids_before_event: Optional[StateMap[str]] = None, - partial_state: Optional[bool] = None, - state_group_before_event: Optional[int] = None, + state_ids_before_event: StateMap[str] | None = None, + partial_state: bool | None = None, + state_group_before_event: int | None = None, ) -> UnpersistedEventContextBase: """ Calulates the contents of an unpersisted event context, other than the current @@ -456,8 +456,8 @@ class StateHandler: async def compute_event_context( self, event: EventBase, - state_ids_before_event: Optional[StateMap[str]] = None, - partial_state: Optional[bool] = None, + state_ids_before_event: StateMap[str] | None = None, + partial_state: bool | None = None, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -670,7 +670,7 @@ class StateResolutionHandler: room_id: str, room_version: str, state_groups_ids: Mapping[int, StateMap[str]], - event_map: Optional[dict[str, EventBase]], + event_map: dict[str, EventBase] | None, state_res_store: "StateResolutionStore", ) -> _StateCacheEntry: """Resolves conflicts between a set of state groups @@ -770,7 +770,7 @@ class StateResolutionHandler: room_id: str, room_version: str, state_sets: Sequence[StateMap[str]], - event_map: Optional[dict[str, EventBase]], + event_map: dict[str, EventBase] | None, state_res_store: "StateResolutionStore", ) -> StateMap[str]: """ @@ -934,7 +934,7 @@ def _make_state_cache_entry( # failing that, look for the closest match. prev_group = None - delta_ids: Optional[StateMap[str]] = None + delta_ids: StateMap[str] | None = None for old_group, old_state in state_groups_ids.items(): if old_state.keys() - new_state.keys(): @@ -991,8 +991,8 @@ class StateResolutionStore: self, room_id: str, state_sets: list[set[str]], - conflicted_state: Optional[set[str]], - additional_backwards_reachable_conflicted_events: Optional[set[str]], + conflicted_state: set[str] | None, + additional_backwards_reachable_conflicted_events: set[str] | None, ) -> Awaitable[StateDifference]: """ "Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). diff --git a/synapse/state/v1.py b/synapse/state/v1.py index a219347264..0b4514d322 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -24,7 +24,6 @@ from typing import ( Awaitable, Callable, Iterable, - Optional, Sequence, ) @@ -45,7 +44,7 @@ async def resolve_events_with_store( room_id: str, room_version: RoomVersion, state_sets: Sequence[StateMap[str]], - event_map: Optional[dict[str, EventBase]], + event_map: dict[str, EventBase] | None, state_map_factory: Callable[[StrCollection], Awaitable[dict[str, EventBase]]], ) -> StateMap[str]: """ diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 683f0c1dcc..c410c3a7ec 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -28,7 +28,6 @@ from typing import ( Generator, Iterable, Literal, - Optional, Protocol, Sequence, overload, @@ -63,8 +62,8 @@ class StateResolutionStore(Protocol): self, room_id: str, state_sets: list[set[str]], - conflicted_state: Optional[set[str]], - additional_backwards_reachable_conflicted_events: Optional[set[str]], + conflicted_state: set[str] | None, + additional_backwards_reachable_conflicted_events: set[str] | None, ) -> Awaitable[StateDifference]: ... @@ -84,7 +83,7 @@ async def resolve_events_with_store( room_id: str, room_version: RoomVersion, state_sets: Sequence[StateMap[str]], - event_map: Optional[dict[str, EventBase]], + event_map: dict[str, EventBase] | None, state_res_store: StateResolutionStore, ) -> StateMap[str]: """Resolves the state using the v2 state resolution algorithm @@ -124,7 +123,7 @@ async def resolve_events_with_store( logger.debug("%d conflicted state entries", len(conflicted_state)) logger.debug("Calculating auth chain difference") - conflicted_set: Optional[set[str]] = None + conflicted_set: set[str] | None = None if room_version.state_res == StateResolutionVersions.V2_1: # calculate the conflicted subgraph conflicted_set = set(itertools.chain.from_iterable(conflicted_state.values())) @@ -313,7 +312,7 @@ async def _get_auth_chain_difference( state_sets: Sequence[StateMap[str]], unpersisted_events: dict[str, EventBase], state_res_store: StateResolutionStore, - conflicted_state: Optional[set[str]], + conflicted_state: set[str] | None, ) -> set[str]: """Compare the auth chains of each state set and return the set of events that only appear in some, but not all of the auth chains. @@ -546,7 +545,7 @@ def _seperate( conflicted_state[key] = event_ids # mypy doesn't understand that discarding None above means that conflicted - # state is StateMap[set[str]], not StateMap[set[Optional[Str]]]. + # state is StateMap[set[str]], not StateMap[set[str | None]]. return unconflicted_state, conflicted_state # type: ignore[return-value] @@ -755,7 +754,7 @@ async def _mainline_sort( clock: Clock, room_id: str, event_ids: list[str], - resolved_power_event_id: Optional[str], + resolved_power_event_id: str | None, event_map: dict[str, EventBase], state_res_store: StateResolutionStore, ) -> list[str]: @@ -842,7 +841,7 @@ async def _get_mainline_depth_for_event( """ room_id = event.room_id - tmp_event: Optional[EventBase] = event + tmp_event: EventBase | None = event # We do an iterative search, replacing `event with the power level in its # auth events (if any) @@ -889,7 +888,7 @@ async def _get_event( event_map: dict[str, EventBase], state_res_store: StateResolutionStore, allow_none: Literal[True], -) -> Optional[EventBase]: ... +) -> EventBase | None: ... async def _get_event( @@ -898,7 +897,7 @@ async def _get_event( event_map: dict[str, EventBase], state_res_store: StateResolutionStore, allow_none: bool = False, -) -> Optional[EventBase]: +) -> EventBase | None: """Helper function to look up event in event_map, falling back to looking it up in the store diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b6958ef06b..8eeea20967 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -21,7 +21,7 @@ # import logging from abc import ABCMeta -from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Collection, Iterable from synapse.storage.database import ( DatabasePool, @@ -176,7 +176,7 @@ class SQLBaseStore(metaclass=ABCMeta): ) def _attempt_to_invalidate_cache( - self, cache_name: str, key: Optional[Collection[Any]] + self, cache_name: str, key: Collection[Any] | None ) -> bool: """Attempts to invalidate the cache of the given name, ignoring if the cache doesn't exist. Mainly used for invalidating caches on workers, @@ -218,7 +218,7 @@ class SQLBaseStore(metaclass=ABCMeta): self.external_cached_functions[cache_name] = func -def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: +def db_to_json(db_content: memoryview | bytes | bytearray | str) -> Any: """ Take some data from a database row and return a JSON-decoded object. diff --git a/synapse/storage/admin_client_config.py b/synapse/storage/admin_client_config.py index 07acddc660..b56e21edfa 100644 --- a/synapse/storage/admin_client_config.py +++ b/synapse/storage/admin_client_config.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from synapse.types import JsonMapping @@ -9,7 +8,7 @@ logger = logging.getLogger(__name__) class AdminClientConfig: """Class to track various Synapse-specific admin-only client-impacting config options.""" - def __init__(self, account_data: Optional[JsonMapping]): + def __init__(self, account_data: JsonMapping | None): # Allow soft-failed events to be returned down `/sync` and other # client APIs. `io.element.synapse.soft_failed: true` is added to the # `unsigned` portion of the event to inform clients that the event diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 1c17d4d609..c71bcdb7fb 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -29,7 +29,6 @@ from typing import ( Awaitable, Callable, Iterable, - Optional, Sequence, cast, ) @@ -169,9 +168,9 @@ class _BackgroundUpdateContextManager: async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> None: pass @@ -196,7 +195,7 @@ class BackgroundUpdatePerformance: self.avg_item_count += 0.1 * (item_count - self.avg_item_count) self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms) - def average_items_per_ms(self) -> Optional[float]: + def average_items_per_ms(self) -> float | None: """An estimate of how long it takes to do a single update. Returns: A duration in ms as a float @@ -212,7 +211,7 @@ class BackgroundUpdatePerformance: # changes in how long the update process takes. return float(self.avg_item_count) / float(self.avg_duration_ms) - def total_items_per_ms(self) -> Optional[float]: + def total_items_per_ms(self) -> float | None: """An estimate of how long it takes to do a single update. Returns: A duration in ms as a float @@ -250,11 +249,11 @@ class BackgroundUpdater: self._database_name = database.name() # if a background update is currently running, its name. - self._current_background_update: Optional[str] = None + self._current_background_update: str | None = None - self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None - self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None - self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None + self._on_update_callback: ON_UPDATE_CALLBACK | None = None + self._default_batch_size_callback: DEFAULT_BATCH_SIZE_CALLBACK | None = None + self._min_batch_size_callback: MIN_BATCH_SIZE_CALLBACK | None = None self._background_update_performance: dict[str, BackgroundUpdatePerformance] = {} self._background_update_handlers: dict[str, _BackgroundUpdateHandler] = {} @@ -304,8 +303,8 @@ class BackgroundUpdater: def register_update_controller_callbacks( self, on_update: ON_UPDATE_CALLBACK, - default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, - min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, + default_batch_size: DEFAULT_BATCH_SIZE_CALLBACK | None = None, + min_batch_size: DEFAULT_BATCH_SIZE_CALLBACK | None = None, ) -> None: """Register callbacks from a module for each hook.""" if self._on_update_callback is not None: @@ -380,7 +379,7 @@ class BackgroundUpdater: return self.minimum_background_batch_size - def get_current_update(self) -> Optional[BackgroundUpdatePerformance]: + def get_current_update(self) -> BackgroundUpdatePerformance | None: """Returns the current background update, if any.""" update_name = self._current_background_update @@ -526,14 +525,14 @@ class BackgroundUpdater: True if we have finished running all the background updates, otherwise False """ - def get_background_updates_txn(txn: Cursor) -> list[tuple[str, Optional[str]]]: + def get_background_updates_txn(txn: Cursor) -> list[tuple[str, str | None]]: txn.execute( """ SELECT update_name, depends_on FROM background_updates ORDER BY ordering, update_name """ ) - return cast(list[tuple[str, Optional[str]]], txn.fetchall()) + return cast(list[tuple[str, str | None]], txn.fetchall()) if not self._current_background_update: all_pending_updates = await self.db_pool.runInteraction( @@ -669,10 +668,10 @@ class BackgroundUpdater: index_name: str, table: str, columns: Iterable[str], - where_clause: Optional[str] = None, + where_clause: str | None = None, unique: bool = False, psql_only: bool = False, - replaces_index: Optional[str] = None, + replaces_index: str | None = None, ) -> None: """Helper for store classes to do a background index addition @@ -763,10 +762,10 @@ class BackgroundUpdater: index_name: str, table: str, columns: Iterable[str], - where_clause: Optional[str] = None, + where_clause: str | None = None, unique: bool = False, psql_only: bool = False, - replaces_index: Optional[str] = None, + replaces_index: str | None = None, ) -> None: """Add an index in the background. @@ -862,7 +861,7 @@ class BackgroundUpdater: c.execute(sql) if isinstance(self.db_pool.engine, engines.PostgresEngine): - runner: Optional[Callable[[LoggingDatabaseConnection], None]] = ( + runner: Callable[[LoggingDatabaseConnection], None] | None = ( create_index_psql ) elif psql_only: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 0daf4830d9..2948227807 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -34,9 +34,7 @@ from typing import ( Generator, Generic, Iterable, - Optional, TypeVar, - Union, ) import attr @@ -164,7 +162,7 @@ class _UpdateCurrentStateTask: return isinstance(task, _UpdateCurrentStateTask) -_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask] +_EventPersistQueueTask = _PersistEventsTask | _UpdateCurrentStateTask _PersistResult = TypeVar("_PersistResult") @@ -674,7 +672,7 @@ class EventsPersistenceStorageController: async def _calculate_new_forward_extremities_and_state_delta( self, room_id: str, ev_ctx_rm: list[EventPersistencePair] - ) -> tuple[Optional[set[str]], Optional[DeltaState]]: + ) -> tuple[set[str] | None, DeltaState | None]: """Calculates the new forward extremities and state delta for a room given events to persist. @@ -861,7 +859,7 @@ class EventsPersistenceStorageController: events_context: list[EventPersistencePair], old_latest_event_ids: AbstractSet[str], new_latest_event_ids: set[str], - ) -> tuple[Optional[StateMap[str]], Optional[StateMap[str]], set[str]]: + ) -> tuple[StateMap[str] | None, StateMap[str] | None, set[str]]: """Calculate the current state dict after adding some new events to a room diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py index 6606fdcc30..4ca3f8f4e1 100644 --- a/synapse/storage/controllers/purge_events.py +++ b/synapse/storage/controllers/purge_events.py @@ -25,7 +25,6 @@ from typing import ( TYPE_CHECKING, Collection, Mapping, - Optional, ) from synapse.logging.context import nested_logging_context @@ -445,7 +444,7 @@ class PurgeEventsStorageController: # Remove state groups from deletion_candidates which are directly referenced or share a # future edge with a referenced state group within this batch. - def filter_reference_chains(group: Optional[int]) -> None: + def filter_reference_chains(group: int | None) -> None: while group is not None: deletion_candidates.discard(group) group = state_group_edges.get(group) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 690a0dde2e..9c5e837ab0 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -27,8 +27,6 @@ from typing import ( Collection, Iterable, Mapping, - Optional, - Union, ) from synapse.api.constants import EventTypes, Membership @@ -91,7 +89,7 @@ class StateStorageController: @tag_args async def get_state_group_delta( self, state_group: int - ) -> tuple[Optional[int], Optional[StateMap[str]]]: + ) -> tuple[int | None, StateMap[str] | None]: """Given a state group try to return a previous group and a delta between the old and the new. @@ -141,7 +139,7 @@ class StateStorageController: @trace @tag_args async def get_state_ids_for_group( - self, state_group: int, state_filter: Optional[StateFilter] = None + self, state_group: int, state_filter: StateFilter | None = None ) -> StateMap[str]: """Get the event IDs of all the state in the given state group @@ -217,7 +215,7 @@ class StateStorageController: @trace @tag_args async def get_state_for_events( - self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None + self, event_ids: Collection[str], state_filter: StateFilter | None = None ) -> dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -271,7 +269,7 @@ class StateStorageController: async def get_state_ids_for_events( self, event_ids: Collection[str], - state_filter: Optional[StateFilter] = None, + state_filter: StateFilter | None = None, await_full_state: bool = True, ) -> dict[str, StateMap[str]]: """ @@ -322,7 +320,7 @@ class StateStorageController: @trace @tag_args async def get_state_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, event_id: str, state_filter: StateFilter | None = None ) -> StateMap[EventBase]: """ Get the state dict corresponding to a particular event @@ -349,7 +347,7 @@ class StateStorageController: async def get_state_ids_for_event( self, event_id: str, - state_filter: Optional[StateFilter] = None, + state_filter: StateFilter | None = None, await_full_state: bool = True, ) -> StateMap[str]: """ @@ -382,7 +380,7 @@ class StateStorageController: async def get_state_after_event( self, event_id: str, - state_filter: Optional[StateFilter] = None, + state_filter: StateFilter | None = None, await_full_state: bool = True, ) -> StateMap[str]: """ @@ -423,7 +421,7 @@ class StateStorageController: self, room_id: str, stream_position: StreamToken, - state_filter: Optional[StateFilter] = None, + state_filter: StateFilter | None = None, await_full_state: bool = True, ) -> StateMap[str]: """Get the room state at a particular stream position @@ -479,7 +477,7 @@ class StateStorageController: self, room_id: str, stream_position: StreamToken, - state_filter: Optional[StateFilter] = None, + state_filter: StateFilter | None = None, await_full_state: bool = True, ) -> StateMap[EventBase]: """Same as `get_state_ids_at` but also fetches the events""" @@ -500,7 +498,7 @@ class StateStorageController: @trace @tag_args async def get_state_for_groups( - self, groups: Iterable[int], state_filter: Optional[StateFilter] = None + self, groups: Iterable[int], state_filter: StateFilter | None = None ) -> dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -546,9 +544,9 @@ class StateStorageController: self, event_id: str, room_id: str, - prev_group: Optional[int], - delta_ids: Optional[StateMap[str]], - current_state_ids: Optional[StateMap[str]], + prev_group: int | None, + delta_ids: StateMap[str] | None, + current_state_ids: StateMap[str] | None, ) -> int: """Store a new set of state, returning a newly assigned state group. @@ -575,9 +573,9 @@ class StateStorageController: async def get_current_state_ids( self, room_id: str, - state_filter: Optional[StateFilter] = None, + state_filter: StateFilter | None = None, await_full_state: bool = True, - on_invalidate: Optional[Callable[[], None]] = None, + on_invalidate: Callable[[], None] | None = None, ) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. @@ -614,7 +612,7 @@ class StateStorageController: @trace @tag_args - async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: + async def get_canonical_alias_for_room(self, room_id: str) -> str | None: """Get canonical alias for room, if any Args: @@ -639,9 +637,7 @@ class StateStorageController: return event.content.get("alias") @cached() - async def get_server_acl_for_room( - self, room_id: str - ) -> Optional[ServerAclEvaluator]: + async def get_server_acl_for_room(self, room_id: str) -> ServerAclEvaluator | None: """Get the server ACL evaluator for room, if any This does up-front parsing of the content to ignore bad data and pre-compile @@ -695,7 +691,7 @@ class StateStorageController: async def get_current_state( self, room_id: str, - state_filter: Optional[StateFilter] = None, + state_filter: StateFilter | None = None, await_full_state: bool = True, ) -> StateMap[EventBase]: """Same as `get_current_state_ids` but also fetches the events""" @@ -717,7 +713,7 @@ class StateStorageController: @tag_args async def get_current_state_event( self, room_id: str, event_type: str, state_key: str - ) -> Optional[EventBase]: + ) -> EventBase | None: """Get the current state event for the given type/state_key.""" key = (event_type, state_key) @@ -804,7 +800,7 @@ class StateStorageController: async def get_joined_hosts( self, room_id: str, state_entry: "_StateCacheEntry" ) -> frozenset[str]: - state_group: Union[object, int] = state_entry.state_group + state_group: object | int = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -822,7 +818,7 @@ class StateStorageController: async def _get_joined_hosts( self, room_id: str, - state_group: Union[object, int], + state_group: object | int, state_entry: "_StateCacheEntry", ) -> frozenset[str]: # We don't use `state_group`, it's there so that we can cache based on diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b7f870bd26..3d351e8aea 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -35,7 +35,6 @@ from typing import ( Iterator, Literal, Mapping, - Optional, Sequence, TypeVar, cast, @@ -213,10 +212,10 @@ class LoggingDatabaseConnection: def cursor( self, *, - txn_name: Optional[str] = None, - after_callbacks: Optional[list["_CallbackListEntry"]] = None, - async_after_callbacks: Optional[list["_AsyncCallbackListEntry"]] = None, - exception_callbacks: Optional[list["_CallbackListEntry"]] = None, + txn_name: str | None = None, + after_callbacks: list["_CallbackListEntry"] | None = None, + async_after_callbacks: list["_AsyncCallbackListEntry"] | None = None, + exception_callbacks: list["_CallbackListEntry"] | None = None, ) -> "LoggingTransaction": if not txn_name: txn_name = self.default_txn_name @@ -246,10 +245,10 @@ class LoggingDatabaseConnection: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[types.TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool | None: return self.conn.__exit__(exc_type, exc_value, traceback) # Proxy through any unknown lookups to the DB conn class. @@ -307,9 +306,9 @@ class LoggingTransaction: name: str, server_name: str, database_engine: BaseDatabaseEngine, - after_callbacks: Optional[list[_CallbackListEntry]] = None, - async_after_callbacks: Optional[list[_AsyncCallbackListEntry]] = None, - exception_callbacks: Optional[list[_CallbackListEntry]] = None, + after_callbacks: list[_CallbackListEntry] | None = None, + async_after_callbacks: list[_AsyncCallbackListEntry] | None = None, + exception_callbacks: list[_CallbackListEntry] | None = None, ): self.txn = txn self.name = name @@ -379,10 +378,10 @@ class LoggingTransaction: assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) - def fetchone(self) -> Optional[tuple]: + def fetchone(self) -> tuple | None: return self.txn.fetchone() - def fetchmany(self, size: Optional[int] = None) -> list[tuple]: + def fetchmany(self, size: int | None = None) -> list[tuple]: return self.txn.fetchmany(size=size) def fetchall(self) -> list[tuple]: @@ -398,7 +397,7 @@ class LoggingTransaction: @property def description( self, - ) -> Optional[Sequence[Any]]: + ) -> Sequence[Any] | None: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: @@ -429,7 +428,7 @@ class LoggingTransaction: self, sql: str, values: Iterable[Iterable[Any]], - template: Optional[str] = None, + template: str | None = None, fetch: bool = True, ) -> list[tuple]: """Corresponds to psycopg2.extras.execute_values. Only available when @@ -536,9 +535,9 @@ class LoggingTransaction: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[types.TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, ) -> None: self.close() @@ -920,7 +919,7 @@ class DatabasePool: func: Callable[..., R], *args: Any, db_autocommit: bool = False, - isolation_level: Optional[int] = None, + isolation_level: int | None = None, **kwargs: Any, ) -> R: """Starts a transaction on the database and runs a given function @@ -1002,7 +1001,7 @@ class DatabasePool: func: Callable[Concatenate[LoggingDatabaseConnection, P], R], *args: Any, db_autocommit: bool = False, - isolation_level: Optional[int] = None, + isolation_level: int | None = None, **kwargs: Any, ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. @@ -1240,8 +1239,8 @@ class DatabasePool: table: str, keyvalues: dict[str, Any], values: dict[str, Any], - insertion_values: Optional[dict[str, Any]] = None, - where_clause: Optional[str] = None, + insertion_values: dict[str, Any] | None = None, + where_clause: str | None = None, desc: str = "simple_upsert", ) -> bool: """Insert a row with values + insertion_values; on conflict, update with values. @@ -1334,8 +1333,8 @@ class DatabasePool: table: str, keyvalues: Mapping[str, Any], values: Mapping[str, Any], - insertion_values: Optional[Mapping[str, Any]] = None, - where_clause: Optional[str] = None, + insertion_values: Mapping[str, Any] | None = None, + where_clause: str | None = None, ) -> bool: """ Pick the UPSERT method which works best on the platform. Either the @@ -1379,8 +1378,8 @@ class DatabasePool: table: str, keyvalues: Mapping[str, Any], values: Mapping[str, Any], - insertion_values: Optional[Mapping[str, Any]] = None, - where_clause: Optional[str] = None, + insertion_values: Mapping[str, Any] | None = None, + where_clause: str | None = None, lock: bool = True, ) -> bool: """ @@ -1460,8 +1459,8 @@ class DatabasePool: table: str, keyvalues: Mapping[str, Any], values: Mapping[str, Any], - insertion_values: Optional[Mapping[str, Any]] = None, - where_clause: Optional[str] = None, + insertion_values: Mapping[str, Any] | None = None, + where_clause: str | None = None, ) -> bool: """ Use the native UPSERT functionality in PostgreSQL. @@ -1728,7 +1727,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[True] = True, desc: str = "simple_select_one", - ) -> Optional[tuple[Any, ...]]: ... + ) -> tuple[Any, ...] | None: ... async def simple_select_one( self, @@ -1737,7 +1736,7 @@ class DatabasePool: retcols: Collection[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> Optional[tuple[Any, ...]]: + ) -> tuple[Any, ...] | None: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -1777,7 +1776,7 @@ class DatabasePool: retcol: str, allow_none: Literal[True] = True, desc: str = "simple_select_one_onecol", - ) -> Optional[Any]: ... + ) -> Any | None: ... async def simple_select_one_onecol( self, @@ -1786,7 +1785,7 @@ class DatabasePool: retcol: str, allow_none: bool = False, desc: str = "simple_select_one_onecol", - ) -> Optional[Any]: + ) -> Any | None: """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -1828,7 +1827,7 @@ class DatabasePool: keyvalues: dict[str, Any], retcol: str, allow_none: Literal[True] = True, - ) -> Optional[Any]: ... + ) -> Any | None: ... @classmethod def simple_select_one_onecol_txn( @@ -1838,7 +1837,7 @@ class DatabasePool: keyvalues: dict[str, Any], retcol: str, allow_none: bool = False, - ) -> Optional[Any]: + ) -> Any | None: ret = cls.simple_select_onecol_txn( txn, table=table, keyvalues=keyvalues, retcol=retcol ) @@ -1871,7 +1870,7 @@ class DatabasePool: async def simple_select_onecol( self, table: str, - keyvalues: Optional[dict[str, Any]], + keyvalues: dict[str, Any] | None, retcol: str, desc: str = "simple_select_onecol", ) -> list[Any]: @@ -1899,7 +1898,7 @@ class DatabasePool: async def simple_select_list( self, table: str, - keyvalues: Optional[dict[str, Any]], + keyvalues: dict[str, Any] | None, retcols: Collection[str], desc: str = "simple_select_list", ) -> list[tuple[Any, ...]]: @@ -1931,7 +1930,7 @@ class DatabasePool: cls, txn: LoggingTransaction, table: str, - keyvalues: Optional[dict[str, Any]], + keyvalues: dict[str, Any] | None, retcols: Iterable[str], ) -> list[tuple[Any, ...]]: """Executes a SELECT query on the named table, which may return zero or @@ -1967,7 +1966,7 @@ class DatabasePool: column: str, iterable: Iterable[Any], retcols: Collection[str], - keyvalues: Optional[dict[str, Any]] = None, + keyvalues: dict[str, Any] | None = None, desc: str = "simple_select_many_batch", batch_size: int = 100, ) -> list[tuple[Any, ...]]: @@ -2249,7 +2248,7 @@ class DatabasePool: keyvalues: dict[str, Any], retcols: Collection[str], allow_none: Literal[True] = True, - ) -> Optional[tuple[Any, ...]]: ... + ) -> tuple[Any, ...] | None: ... @staticmethod def simple_select_one_txn( @@ -2258,7 +2257,7 @@ class DatabasePool: keyvalues: dict[str, Any], retcols: Collection[str], allow_none: bool = False, - ) -> Optional[tuple[Any, ...]]: + ) -> tuple[Any, ...] | None: select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) if keyvalues: @@ -2529,9 +2528,9 @@ class DatabasePool: start: int, limit: int, retcols: Iterable[str], - filters: Optional[dict[str, Any]] = None, - keyvalues: Optional[dict[str, Any]] = None, - exclude_keyvalues: Optional[dict[str, Any]] = None, + filters: dict[str, Any] | None = None, + keyvalues: dict[str, Any] | None = None, + exclude_keyvalues: dict[str, Any] | None = None, order_direction: str = "ASC", ) -> list[tuple[Any, ...]]: """ diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index f145d21096..b44b84b913 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.storage._base import SQLBaseStore @@ -64,7 +64,7 @@ class Databases(Generic[DataStoreT]): databases: list[DatabasePool] main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class` state: StateGroupDataStore - persist_events: Optional[PersistEventsStore] + persist_events: PersistEventsStore | None state_deletion: StateDeletionDataStore def __init__(self, main_store_class: type[DataStoreT], hs: "HomeServer"): @@ -72,10 +72,10 @@ class Databases(Generic[DataStoreT]): # store. self.databases = [] - main: Optional[DataStoreT] = None - state: Optional[StateGroupDataStore] = None - state_deletion: Optional[StateDeletionDataStore] = None - persist_events: Optional[PersistEventsStore] = None + main: DataStoreT | None = None + state: StateGroupDataStore | None = None + state_deletion: StateDeletionDataStore | None = None + persist_events: PersistEventsStore | None = None server_name = hs.hostname diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 9f23c1a4e0..12593094f1 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -20,7 +20,7 @@ # # import logging -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, cast import attr @@ -99,14 +99,14 @@ class UserPaginateResponse: """This is very similar to UserInfo, but not quite the same.""" name: str - user_type: Optional[str] + user_type: str | None is_guest: bool admin: bool deactivated: bool shadow_banned: bool - displayname: Optional[str] - avatar_url: Optional[str] - creation_ts: Optional[int] + displayname: str | None + avatar_url: str | None + creation_ts: int | None approved: bool erased: bool last_seen_ts: int @@ -180,15 +180,15 @@ class DataStore( self, start: int, limit: int, - user_id: Optional[str] = None, - name: Optional[str] = None, + user_id: str | None = None, + name: str | None = None, guests: bool = True, - deactivated: Optional[bool] = None, - admins: Optional[bool] = None, + deactivated: bool | None = None, + admins: bool | None = None, order_by: str = UserSortOrder.NAME.value, direction: Direction = Direction.FORWARDS, approved: bool = True, - not_user_types: Optional[list[str]] = None, + not_user_types: list[str] | None = None, locked: bool = False, ) -> tuple[list[UserPaginateResponse], int]: """Function to retrieve a paginated list of users from @@ -351,9 +351,7 @@ class DataStore( async def search_users( self, term: str - ) -> list[ - tuple[str, Optional[str], Union[int, bool], Union[int, bool], Optional[str]] - ]: + ) -> list[tuple[str, str | None, int | bool, int | bool, str | None]]: """Function to search users list for one or more users with the matched term. @@ -366,9 +364,7 @@ class DataStore( def search_users( txn: LoggingTransaction, - ) -> list[ - tuple[str, Optional[str], Union[int, bool], Union[int, bool], Optional[str]] - ]: + ) -> list[tuple[str, str | None, int | bool, int | bool, str | None]]: search_term = "%%" + term + "%%" sql = """ @@ -382,10 +378,10 @@ class DataStore( list[ tuple[ str, - Optional[str], - Union[int, bool], - Union[int, bool], - Optional[str], + str | None, + int | bool, + int | bool, + str | None, ] ], txn.fetchall(), diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index f1fb5fe188..15728cf618 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -25,7 +25,6 @@ from typing import ( Any, Iterable, Mapping, - Optional, cast, ) @@ -213,7 +212,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached(num_args=2, max_entries=5000, tree=True) async def get_global_account_data_by_type_for_user( self, user_id: str, data_type: str - ) -> Optional[JsonMapping]: + ) -> JsonMapping | None: """ Returns: The account data. @@ -233,7 +232,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_latest_stream_id_for_global_account_data_by_type_for_user( self, user_id: str, data_type: str - ) -> Optional[int]: + ) -> int | None: """ Returns: The stream ID of the account data, @@ -242,7 +241,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_latest_stream_id_for_global_account_data_by_type_for_user_txn( txn: LoggingTransaction, - ) -> Optional[int]: + ) -> int | None: sql = """ SELECT stream_id FROM account_data WHERE user_id = ? AND account_data_type = ? @@ -300,7 +299,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached(num_args=3, max_entries=5000, tree=True) async def get_account_data_for_room_and_type( self, user_id: str, room_id: str, account_data_type: str - ) -> Optional[JsonMapping]: + ) -> JsonMapping | None: """Get the client account_data of given type for a user for a room. Args: @@ -313,7 +312,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_account_data_for_room_and_type_txn( txn: LoggingTransaction, - ) -> Optional[JsonDict]: + ) -> JsonDict | None: content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 7558672905..6c2bf90b37 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -20,7 +20,7 @@ # import logging import re -from typing import TYPE_CHECKING, Optional, Pattern, Sequence, cast +from typing import TYPE_CHECKING, Pattern, Sequence, cast from synapse.appservice import ( ApplicationService, @@ -53,7 +53,7 @@ logger = logging.getLogger(__name__) def _make_exclusive_regex( services_cache: list[ApplicationService], -) -> Optional[Pattern]: +) -> Pattern | None: # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ @@ -63,7 +63,7 @@ def _make_exclusive_regex( ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex) + exclusive_user_pattern: Pattern | None = re.compile(exclusive_user_regex) else: # We handle this case specially otherwise the constructed regex # will always match @@ -116,7 +116,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): else: return False - def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]: + def get_app_service_by_user_id(self, user_id: str) -> ApplicationService | None: """Retrieve an application service from their user ID. All application services have associated with them a particular user ID. @@ -134,7 +134,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): return service return None - def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]: + def get_app_service_by_token(self, token: str) -> ApplicationService | None: """Get the application service with the given appservice token. Args: @@ -147,7 +147,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): return service return None - def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]: + def get_app_service_by_id(self, as_id: str) -> ApplicationService | None: """Get the application service with the given appservice ID. Args: @@ -227,7 +227,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_appservice_state( self, service: ApplicationService - ) -> Optional[ApplicationServiceState]: + ) -> ApplicationServiceState | None: """Get the application service state. Args: @@ -347,7 +347,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_oldest_unsent_txn( self, service: ApplicationService - ) -> Optional[AppServiceTransaction]: + ) -> AppServiceTransaction | None: """Get the oldest transaction which has not been sent for this service. Args: @@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore( def _get_oldest_unsent_txn( txn: LoggingTransaction, - ) -> Optional[tuple[int, str]]: + ) -> tuple[int, str] | None: # Monotonically increasing txn ids, so just select the smallest # one in the txns table (we delete them when they are sent) txn.execute( @@ -366,7 +366,7 @@ class ApplicationServiceTransactionWorkerStore( " ORDER BY txn_id ASC LIMIT 1", (service.id,), ) - return cast(Optional[tuple[int, str]], txn.fetchone()) + return cast(tuple[int, str] | None, txn.fetchone()) entry = await self.db_pool.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn @@ -447,7 +447,7 @@ class ApplicationServiceTransactionWorkerStore( ) async def set_appservice_stream_type_pos( - self, service: ApplicationService, stream_type: str, pos: Optional[int] + self, service: ApplicationService, stream_type: str, pos: int | None ) -> None: if stream_type not in ("read_receipt", "presence", "to_device", "device_list"): raise ValueError( diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 5a96510b13..b7b9b42461 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -23,7 +23,7 @@ import itertools import json import logging -from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional +from typing import TYPE_CHECKING, Any, Collection, Iterable from synapse.api.constants import EventTypes from synapse.config._base import Config @@ -104,7 +104,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): psql_only=True, # The table is only on postgres DBs. ) - self._cache_id_gen: Optional[MultiWriterIdGenerator] + self._cache_id_gen: MultiWriterIdGenerator | None if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about # missing updates over restarts, as we'll not have anything in our @@ -381,9 +381,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): event_id: str, room_id: str, etype: str, - state_key: Optional[str], - redacts: Optional[str], - relates_to: Optional[str], + state_key: str | None, + redacts: str | None, + relates_to: str | None, backfilled: bool, ) -> None: # This is needed to avoid a circular import. @@ -699,7 +699,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) async def send_invalidation_to_replication( - self, cache_name: str, keys: Optional[Collection[Any]] + self, cache_name: str, keys: Collection[Any] | None ) -> None: await self.db_pool.runInteraction( "send_invalidation_to_replication", @@ -709,7 +709,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def _send_invalidation_to_replication( - self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] + self, txn: LoggingTransaction, cache_name: str, keys: Iterable[Any] | None ) -> None: """Notifies replication that given cache has been invalidated. diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 45cfe97dba..5d667a5345 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -121,7 +121,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json: Optional[str] = json_encoder.encode( + pruned_json: str | None = json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 1033d85a40..4948d0c286 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -23,9 +23,7 @@ import logging from typing import ( TYPE_CHECKING, Mapping, - Optional, TypedDict, - Union, cast, ) @@ -64,9 +62,9 @@ class DeviceLastConnectionInfo: user_id: str device_id: str - ip: Optional[str] - user_agent: Optional[str] - last_seen: Optional[int] + ip: str | None + user_agent: str | None + last_seen: int | None class LastConnectionInfo(TypedDict): @@ -176,7 +174,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # Fetch the start of the batch begin_last_seen: int = progress.get("last_seen", 0) - def get_last_seen(txn: LoggingTransaction) -> Optional[int]: + def get_last_seen(txn: LoggingTransaction) -> int | None: txn.execute( """ SELECT last_seen FROM user_ips @@ -187,7 +185,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): """, (begin_last_seen, batch_size), ) - row = cast(Optional[tuple[int]], txn.fetchone()) + row = cast(tuple[int] | None, txn.fetchone()) if row: return row[0] else: @@ -248,7 +246,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): args, ) res = cast( - list[tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall() + list[tuple[str, str, str, str | None, str, int, int]], txn.fetchall() ) # We've got some duplicates @@ -358,7 +356,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # we'll just end up updating the same device row multiple # times, which is fine. - where_args: list[Union[str, int]] + where_args: list[str | int] where_clause, where_args = make_tuple_comparison_clause( [("user_id", last_user_id), ("device_id", last_device_id)], ) @@ -447,7 +445,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) self._batch_row_update: dict[ - tuple[str, str, str], tuple[str, Optional[str], int] + tuple[str, str, str], tuple[str, str | None, int] ] = {} self.clock.looping_call(self._update_client_ips_batch, 5 * 1000) @@ -500,7 +498,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke ) async def _get_last_client_ip_by_device_from_database( - self, user_id: str, device_id: Optional[str] + self, user_id: str, device_id: str | None ) -> dict[tuple[str, str], DeviceLastConnectionInfo]: """For each device_id listed, give the user_ip it was last seen on. @@ -519,7 +517,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke keyvalues["device_id"] = device_id res = cast( - list[tuple[str, Optional[str], Optional[str], str, Optional[int]]], + list[tuple[str, str | None, str | None, str, int | None]], await self.db_pool.simple_select_list( table="devices", keyvalues=keyvalues, @@ -596,8 +594,8 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke access_token: str, ip: str, user_agent: str, - device_id: Optional[str], - now: Optional[int] = None, + device_id: str | None, + now: int | None = None, ) -> None: """Record that `user_id` used `access_token` from this `ip` address. @@ -670,7 +668,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke def _update_client_ips_batch_txn( self, txn: LoggingTransaction, - to_update: Mapping[tuple[str, str, str], tuple[str, Optional[str], int]], + to_update: Mapping[tuple[str, str, str], tuple[str, str | None, int]], ) -> None: assert self._update_on_this_worker, ( "This worker is not designated to update client IPs" @@ -715,7 +713,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke ) async def get_last_client_ip_by_device( - self, user_id: str, device_id: Optional[str] + self, user_id: str, device_id: str | None ) -> dict[tuple[str, str], DeviceLastConnectionInfo]: """For each device_id listed, give the user_ip it was last seen on @@ -805,7 +803,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke return list(results.values()) - async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]: + async def get_last_seen_for_user_id(self, user_id: str) -> int | None: """Get the last seen timestamp for a user, if we have it.""" return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 6ad161db33..b11ed86db2 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -13,7 +13,7 @@ # import logging -from typing import NewType, Optional +from typing import NewType import attr @@ -42,10 +42,10 @@ Timestamp = NewType("Timestamp", int) class EventDetails: room_id: RoomID type: EventType - state_key: Optional[StateKey] - origin_server_ts: Optional[Timestamp] + state_key: StateKey | None + origin_server_ts: Timestamp | None content: JsonDict - device_id: Optional[DeviceID] + device_id: DeviceID | None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -67,7 +67,7 @@ class DelayedEventsStore(SQLBaseStore): desc="get_delayed_events_stream_pos", ) - async def update_delayed_events_stream_pos(self, stream_id: Optional[int]) -> None: + async def update_delayed_events_stream_pos(self, stream_id: int | None) -> None: """ Updates the stream position of the background process to watch for state events that target the same piece of state as any pending delayed events. @@ -85,12 +85,12 @@ class DelayedEventsStore(SQLBaseStore): self, *, user_localpart: str, - device_id: Optional[str], + device_id: str | None, creation_ts: Timestamp, room_id: str, event_type: str, - state_key: Optional[str], - origin_server_ts: Optional[int], + state_key: str | None, + origin_server_ts: int | None, content: JsonDict, delay: int, ) -> tuple[DelayID, Timestamp]: @@ -238,7 +238,7 @@ class DelayedEventsStore(SQLBaseStore): self, current_ts: Timestamp ) -> tuple[ list[DelayedEventDetails], - Optional[Timestamp], + Timestamp | None, ]: """ Marks for processing all delayed events that should have been sent prior to the provided time @@ -252,7 +252,7 @@ class DelayedEventsStore(SQLBaseStore): txn: LoggingTransaction, ) -> tuple[ list[DelayedEventDetails], - Optional[Timestamp], + Timestamp | None, ]: sql_cols = ", ".join( ( @@ -324,7 +324,7 @@ class DelayedEventsStore(SQLBaseStore): user_localpart: str, ) -> tuple[ EventDetails, - Optional[Timestamp], + Timestamp | None, ]: """ Marks for processing the matching delayed event, regardless of its timeout time, @@ -345,7 +345,7 @@ class DelayedEventsStore(SQLBaseStore): txn: LoggingTransaction, ) -> tuple[ EventDetails, - Optional[Timestamp], + Timestamp | None, ]: txn.execute( """ @@ -390,7 +390,7 @@ class DelayedEventsStore(SQLBaseStore): *, delay_id: str, user_localpart: str, - ) -> Optional[Timestamp]: + ) -> Timestamp | None: """ Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed. @@ -406,7 +406,7 @@ class DelayedEventsStore(SQLBaseStore): def cancel_delayed_event_txn( txn: LoggingTransaction, - ) -> Optional[Timestamp]: + ) -> Timestamp | None: try: self.db_pool.simple_delete_one_txn( txn, @@ -436,7 +436,7 @@ class DelayedEventsStore(SQLBaseStore): event_type: str, state_key: str, not_from_localpart: str, - ) -> Optional[Timestamp]: + ) -> Timestamp | None: """ Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed. @@ -452,7 +452,7 @@ class DelayedEventsStore(SQLBaseStore): def cancel_delayed_state_events_txn( txn: LoggingTransaction, - ) -> Optional[Timestamp]: + ) -> Timestamp | None: txn.execute( """ DELETE FROM delayed_events @@ -526,7 +526,7 @@ class DelayedEventsStore(SQLBaseStore): desc="unprocess_delayed_events", ) - async def get_next_delayed_event_send_ts(self) -> Optional[Timestamp]: + async def get_next_delayed_event_send_ts(self) -> Timestamp | None: """ Returns the send time of the next delayed event to be sent, if any. """ @@ -538,7 +538,7 @@ class DelayedEventsStore(SQLBaseStore): def _get_next_delayed_event_send_ts_txn( self, txn: LoggingTransaction - ) -> Optional[Timestamp]: + ) -> Timestamp | None: result = self.db_pool.simple_select_one_onecol_txn( txn, table="delayed_events", diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 49a82b98d3..a12411d723 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -25,7 +25,6 @@ from typing import ( TYPE_CHECKING, Collection, Iterable, - Optional, cast, ) @@ -87,15 +86,15 @@ class DeviceInboxWorkerStore(SQLBaseStore): # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. - self._last_device_delete_cache: ExpiringCache[ - tuple[str, Optional[str]], int - ] = ExpiringCache( - cache_name="last_device_delete_cache", - server_name=self.server_name, - hs=hs, - clock=self.clock, - max_len=10000, - expiry_ms=30 * 60 * 1000, + self._last_device_delete_cache: ExpiringCache[tuple[str, str | None], int] = ( + ExpiringCache( + cache_name="last_device_delete_cache", + server_name=self.server_name, + hs=hs, + clock=self.clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) ) self._can_write_to_device = ( @@ -469,7 +468,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): async def delete_messages_for_device( self, user_id: str, - device_id: Optional[str], + device_id: str | None, up_to_stream_id: int, ) -> int: """ @@ -527,11 +526,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): async def delete_messages_for_device_between( self, user_id: str, - device_id: Optional[str], - from_stream_id: Optional[int], + device_id: str | None, + from_stream_id: int | None, to_stream_id: int, limit: int, - ) -> tuple[Optional[int], int]: + ) -> tuple[int | None, int]: """Delete N device messages between the stream IDs, returning the highest stream ID deleted (or None if all messages in the range have been deleted) and the number of messages deleted. @@ -551,7 +550,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): def delete_messages_for_device_between_txn( txn: LoggingTransaction, - ) -> tuple[Optional[int], int]: + ) -> tuple[int | None, int]: txn.execute( """ SELECT MAX(stream_id) FROM ( @@ -1147,7 +1146,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): # There's a type mismatch here between how we want to type the row and # what fetchone says it returns, but we silence it because we know that # res can't be None. - res = cast(tuple[Optional[int]], txn.fetchone()) + res = cast(tuple[int | None], txn.fetchone()) if res[0] is None: # this can only happen if the `device_inbox` table is empty, in which # case we have no work to do. @@ -1210,7 +1209,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): max_stream_id = progress["max_stream_id"] else: txn.execute("SELECT max(stream_id) FROM device_federation_outbox") - res = cast(tuple[Optional[int]], txn.fetchone()) + res = cast(tuple[int | None], txn.fetchone()) if res[0] is None: # this can only happen if the `device_inbox` table is empty, in which # case we have no work to do. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index bf5e05ea51..caae2a0648 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -26,7 +26,6 @@ from typing import ( Collection, Iterable, Mapping, - Optional, cast, ) @@ -254,7 +253,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return self._device_list_id_gen async def count_devices_by_users( - self, user_ids: Optional[Collection[str]] = None + self, user_ids: Collection[str] | None = None ) -> int: """Retrieve number of all devices of given users. Only returns number of devices that are not marked as hidden. @@ -293,9 +292,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self, user_id: str, device_id: str, - initial_device_display_name: Optional[str], - auth_provider_id: Optional[str] = None, - auth_provider_session_id: Optional[str] = None, + initial_device_display_name: str | None, + auth_provider_id: str | None = None, + auth_provider_session_id: str | None = None, ) -> bool: """Ensure the given device is known; add it to the store if not @@ -441,7 +440,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) async def update_device( - self, user_id: str, device_id: str, new_display_name: Optional[str] = None + self, user_id: str, device_id: str, new_display_name: str | None = None ) -> None: """Update a device. Only updates the device if it is not marked as hidden. @@ -469,7 +468,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @cached(tree=True) async def get_device( self, user_id: str, device_id: str - ) -> Optional[Mapping[str, Any]]: + ) -> Mapping[str, Any] | None: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -493,7 +492,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): async def get_devices_by_user( self, user_id: str - ) -> dict[str, dict[str, Optional[str]]]: + ) -> dict[str, dict[str, str | None]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. @@ -504,7 +503,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): and "display_name" for each device. Display name may be null. """ devices = cast( - list[tuple[str, str, Optional[str]]], + list[tuple[str, str, str | None]], await self.db_pool.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, @@ -655,7 +654,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): last_processed_stream_id = from_stream_id # A map of (user ID, device ID) to (stream ID, context). - query_map: dict[tuple[str, str], tuple[int, Optional[str]]] = {} + query_map: dict[tuple[str, str], tuple[int, str | None]] = {} cross_signing_keys_by_user: dict[str, dict[str, object]] = {} for user_id, device_id, update_stream_id, update_context in updates: # Calculate the remaining length budget. @@ -762,7 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): from_stream_id: int, now_stream_id: int, limit: int, - ) -> list[tuple[str, str, int, Optional[str]]]: + ) -> list[tuple[str, str, int, str | None]]: """Return device update information for a given remote destination Args: @@ -788,13 +787,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): """ txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) - return cast(list[tuple[str, str, int, Optional[str]]], txn.fetchall()) + return cast(list[tuple[str, str, int, str | None]], txn.fetchall()) async def _get_device_update_edus_by_remote( self, destination: str, from_stream_id: int, - query_map: dict[tuple[str, str], tuple[int, Optional[str]]], + query_map: dict[tuple[str, str], tuple[int, str | None]], ) -> list[tuple[str, dict]]: """Returns a list of device update EDUs as well as E2EE keys @@ -1126,7 +1125,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self, from_key: MultiWriterStreamToken, user_ids: Collection[str], - to_key: Optional[MultiWriterStreamToken] = None, + to_key: MultiWriterStreamToken | None = None, ) -> set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. @@ -1298,7 +1297,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @cached(max_entries=10000) async def get_device_list_last_stream_id_for_remote( self, user_id: str - ) -> Optional[str]: + ) -> str | None: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ @@ -1316,7 +1315,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) async def get_device_list_last_stream_id_for_remotes( self, user_ids: Iterable[str] - ) -> Mapping[str, Optional[str]]: + ) -> Mapping[str, str | None]: rows = cast( list[tuple[str, str]], await self.db_pool.simple_select_many_batch( @@ -1328,14 +1327,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ), ) - results: dict[str, Optional[str]] = dict.fromkeys(user_ids) + results: dict[str, str | None] = dict.fromkeys(user_ids) results.update(rows) return results async def get_user_ids_requiring_device_list_resync( self, - user_ids: Optional[Collection[str]] = None, + user_ids: Collection[str] | None = None, ) -> set[str]: """Given a list of remote users return the list of users that we should resync the device lists for. If None is given instead of a list, @@ -1457,9 +1456,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): txn, self.get_device_list_last_stream_id_for_remote, (user_id,) ) - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[tuple[str, JsonDict]]: + async def get_dehydrated_device(self, user_id: str) -> tuple[str, JsonDict] | None: """Retrieve the information for a dehydrated device. Args: @@ -1484,8 +1481,8 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): device_id: str, device_data: str, time: int, - keys: Optional[JsonDict] = None, - ) -> Optional[str]: + keys: JsonDict | None = None, + ) -> str | None: # TODO: make keys non-optional once support for msc2697 is dropped if keys: device_keys = keys.get("device_keys", None) @@ -1534,8 +1531,8 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): device_id: str, device_data: JsonDict, time_now: int, - keys: Optional[dict] = None, - ) -> Optional[str]: + keys: dict | None = None, + ) -> str | None: """Store a dehydrated device for a user. Args: @@ -1724,7 +1721,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): room_ids: Collection[str], from_token: MultiWriterStreamToken, to_token: MultiWriterStreamToken, - ) -> Optional[set[str]]: + ) -> set[str] | None: """Return the set of users whose devices have changed in the given rooms since the given stream ID. @@ -1963,7 +1960,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_id: str, device_ids: StrCollection, room_ids: StrCollection, - ) -> Optional[int]: + ) -> int | None: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. @@ -2012,7 +2009,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return stream_ids[-1] - last_stream_id: Optional[int] = None + last_stream_id: int | None = None for batch_device_ids in batch_iter(device_ids, 1000): last_stream_id = await self.db_pool.runInteraction( "add_device_change_to_stream", @@ -2072,7 +2069,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): device_id: str, hosts: Collection[str], stream_id: int, - context: Optional[dict[str, str]], + context: dict[str, str] | None, ) -> None: if self._device_list_federation_stream_cache: for host in hosts: @@ -2204,7 +2201,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): async def get_uncoverted_outbound_room_pokes( self, start_stream_id: int, start_room_id: str, limit: int = 10 - ) -> list[tuple[str, str, str, int, Optional[dict[str, str]]]]: + ) -> list[tuple[str, str, str, int, dict[str, str] | None]]: """Get device list changes by room that have not yet been handled and written to `device_lists_outbound_pokes`. @@ -2232,7 +2229,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_uncoverted_outbound_room_pokes_txn( txn: LoggingTransaction, - ) -> list[tuple[str, str, str, int, Optional[dict[str, str]]]]: + ) -> list[tuple[str, str, str, int, dict[str, str] | None]]: txn.execute( sql, ( @@ -2266,7 +2263,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): device_id: str, room_id: str, hosts: Collection[str], - context: Optional[dict[str, str]], + context: dict[str, str] | None, ) -> None: """Queue the device update to be sent to the given set of hosts, calculated from the room ID. diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 99a951ca4a..5e14ad4480 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -19,7 +19,7 @@ # # -from typing import Iterable, Optional, Sequence +from typing import Iterable, Sequence import attr @@ -40,7 +40,7 @@ class RoomAliasMapping: class DirectoryWorkerStore(CacheInvalidationWorkerStore): async def get_association_from_room_alias( self, room_alias: RoomAlias - ) -> Optional[RoomAliasMapping]: + ) -> RoomAliasMapping | None: """Gets the room_id and server list for a given room_alias Args: @@ -94,7 +94,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore): room_alias: RoomAlias, room_id: str, servers: Iterable[str], - creator: Optional[str] = None, + creator: str | None = None, ) -> None: """Creates an association between a room alias and room_id/servers @@ -136,7 +136,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore): 409, "Room alias %s already exists" % room_alias.to_string() ) - async def delete_room_alias(self, room_alias: RoomAlias) -> Optional[str]: + async def delete_room_alias(self, room_alias: RoomAlias) -> str | None: room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) @@ -145,7 +145,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore): def _delete_room_alias_txn( self, txn: LoggingTransaction, room_alias: RoomAlias - ) -> Optional[str]: + ) -> str | None: txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),), @@ -174,7 +174,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore): self, old_room_id: str, new_room_id: str, - creator: Optional[str] = None, + creator: str | None = None, ) -> None: """Repoint all of the aliases for a given room, to a different room. diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index a4d03d1d90..01e9fb4dcf 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -24,7 +24,6 @@ from typing import ( Iterable, Literal, Mapping, - Optional, TypedDict, cast, ) @@ -252,8 +251,8 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): self, user_id: str, version: str, - room_id: Optional[str] = None, - session_id: Optional[str] = None, + room_id: str | None = None, + session_id: str | None = None, ) -> dict[ Literal["rooms"], dict[str, dict[Literal["sessions"], dict[str, RoomKey]]] ]: @@ -438,8 +437,8 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): self, user_id: str, version: str, - room_id: Optional[str] = None, - session_id: Optional[str] = None, + room_id: str | None = None, + session_id: str | None = None, ) -> None: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. @@ -480,13 +479,13 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): ) # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will # be `NULL` when there are no available versions. - row = cast(tuple[Optional[int]], txn.fetchone()) + row = cast(tuple[int | None], txn.fetchone()) if row[0] is None: raise StoreError(404, "No current backup version") return row[0] async def get_e2e_room_keys_version_info( - self, user_id: str, version: Optional[str] = None + self, user_id: str, version: str | None = None ) -> JsonDict: """Get info metadata about a version of our room_keys backup. @@ -556,7 +555,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", (user_id,), ) - current_version = cast(tuple[Optional[int]], txn.fetchone())[0] + current_version = cast(tuple[int | None], txn.fetchone())[0] if current_version is None: current_version = 0 @@ -584,8 +583,8 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): self, user_id: str, version: str, - info: Optional[JsonDict] = None, - version_etag: Optional[int] = None, + info: JsonDict | None = None, + version_etag: int | None = None, ) -> None: """Update a given backup version @@ -621,7 +620,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): @trace async def delete_e2e_room_keys_version( - self, user_id: str, version: Optional[str] = None + self, user_id: str, version: str | None = None ) -> None: """Delete a given backup version of the user's room keys. Doesn't delete their actual key data. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 991d64db44..c93ebd3dda 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -28,9 +28,7 @@ from typing import ( Iterable, Literal, Mapping, - Optional, Sequence, - Union, cast, overload, ) @@ -71,12 +69,12 @@ if TYPE_CHECKING: class DeviceKeyLookupResult: """The type returned by get_e2e_device_keys_and_signatures""" - display_name: Optional[str] + display_name: str | None # the key data from e2e_device_keys_json. Typically includes fields like # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.) - keys: Optional[JsonDict] + keys: JsonDict | None class EndToEndKeyBackgroundStore(SQLBaseStore): @@ -237,7 +235,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cancellable async def get_e2e_device_keys_for_cs_api( self, - query_list: Collection[tuple[str, Optional[str]]], + query_list: Collection[tuple[str, str | None]], include_displaynames: bool = True, ) -> dict[str, dict[str, JsonDict]]: """Fetch a list of device keys, formatted suitably for the C/S API. @@ -280,14 +278,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @overload async def get_e2e_device_keys_and_signatures( self, - query_list: Collection[tuple[str, Optional[str]]], + query_list: Collection[tuple[str, str | None]], include_all_devices: Literal[False] = False, ) -> dict[str, dict[str, DeviceKeyLookupResult]]: ... @overload async def get_e2e_device_keys_and_signatures( self, - query_list: Collection[tuple[str, Optional[str]]], + query_list: Collection[tuple[str, str | None]], include_all_devices: bool = False, include_deleted_devices: Literal[False] = False, ) -> dict[str, dict[str, DeviceKeyLookupResult]]: ... @@ -295,22 +293,22 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @overload async def get_e2e_device_keys_and_signatures( self, - query_list: Collection[tuple[str, Optional[str]]], + query_list: Collection[tuple[str, str | None]], include_all_devices: Literal[True], include_deleted_devices: Literal[True], - ) -> dict[str, dict[str, Optional[DeviceKeyLookupResult]]]: ... + ) -> dict[str, dict[str, DeviceKeyLookupResult | None]]: ... @trace @cancellable async def get_e2e_device_keys_and_signatures( self, - query_list: Collection[tuple[str, Optional[str]]], + query_list: Collection[tuple[str, str | None]], include_all_devices: bool = False, include_deleted_devices: bool = False, - ) -> Union[ - dict[str, dict[str, DeviceKeyLookupResult]], - dict[str, dict[str, Optional[DeviceKeyLookupResult]]], - ]: + ) -> ( + dict[str, dict[str, DeviceKeyLookupResult]] + | dict[str, dict[str, DeviceKeyLookupResult | None]] + ): """Fetch a list of device keys Any cross-signatures made on the keys by the owner of the device are also @@ -384,10 +382,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def _get_e2e_device_keys( self, - query_list: Collection[tuple[str, Optional[str]]], + query_list: Collection[tuple[str, str | None]], include_all_devices: bool = False, include_deleted_devices: bool = False, - ) -> dict[str, dict[str, Optional[DeviceKeyLookupResult]]]: + ) -> dict[str, dict[str, DeviceKeyLookupResult | None]]: """Get information on devices from the database The results include the device's keys and self-signatures, but *not* any @@ -433,7 +431,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker query_clauses.append(user_device_id_in_list_clause) query_params_list.append(user_device_args) - result: dict[str, dict[str, Optional[DeviceKeyLookupResult]]] = {} + result: dict[str, dict[str, DeviceKeyLookupResult | None]] = {} def get_e2e_device_keys_txn( txn: LoggingTransaction, query_clause: str, query_params: list @@ -897,8 +895,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) async def get_e2e_cross_signing_key( - self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Optional[JsonMapping]: + self, user_id: str, key_type: str, from_user_id: str | None = None + ) -> JsonMapping | None: """Returns a user's cross-signing key. Args: @@ -934,7 +932,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: + ) -> Mapping[str, Mapping[str, JsonMapping] | None]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -1013,9 +1011,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker def _get_e2e_cross_signing_signatures_txn( self, txn: LoggingTransaction, - keys: dict[str, Optional[dict[str, JsonDict]]], + keys: dict[str, dict[str, JsonDict] | None], from_user_id: str, - ) -> dict[str, Optional[dict[str, JsonDict]]]: + ) -> dict[str, dict[str, JsonDict] | None]: """Returns the cross-signing signatures made by a user on a set of keys. Args: @@ -1096,8 +1094,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cancellable async def get_e2e_cross_signing_keys_bulk( - self, user_ids: list[str], from_user_id: Optional[str] = None - ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: + self, user_ids: list[str], from_user_id: str | None = None + ) -> Mapping[str, Mapping[str, JsonMapping] | None]: """Returns the cross-signing keys for a set of users. Args: @@ -1114,7 +1112,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if from_user_id: result = cast( - dict[str, Optional[Mapping[str, JsonMapping]]], + dict[str, Mapping[str, JsonMapping] | None], await self.db_pool.runInteraction( "get_e2e_cross_signing_signatures", self._get_e2e_cross_signing_signatures_txn, @@ -1478,7 +1476,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def get_master_cross_signing_key_updatable_before( self, user_id: str - ) -> tuple[bool, Optional[int]]: + ) -> tuple[bool, int | None]: """Get time before which a master cross-signing key may be replaced without UIA. (UIA means "User-Interactive Auth".) @@ -1499,7 +1497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ - def impl(txn: LoggingTransaction) -> tuple[bool, Optional[int]]: + def impl(txn: LoggingTransaction) -> tuple[bool, int | None]: # We want to distinguish between three cases: txn.execute( """ @@ -1511,7 +1509,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """, (user_id,), ) - row = cast(Optional[tuple[Optional[int]]], txn.fetchone()) + row = cast(tuple[int | None] | None, txn.fetchone()) if row is None: return False, None return True, row[0] @@ -1571,7 +1569,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def allow_master_cross_signing_key_replacement_without_uia( self, user_id: str, duration_ms: int - ) -> Optional[int]: + ) -> int | None: """Mark this user's latest master key as being replaceable without UIA. Said replacement will only be permitted for a short time after calling this @@ -1583,7 +1581,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ timestamp = self.clock.time_msec() + duration_ms - def impl(txn: LoggingTransaction) -> Optional[int]: + def impl(txn: LoggingTransaction) -> int | None: txn.execute( """ UPDATE e2e_cross_signing_keys diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0a8571f0c8..b2f0aeaf58 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -27,7 +27,6 @@ from typing import ( Collection, Generator, Iterable, - Optional, Sequence, cast, ) @@ -129,7 +128,7 @@ class StateDifference: # The event IDs in the auth difference. auth_difference: set[str] # The event IDs in the conflicted state subgraph. Used in v2.1 only. - conflicted_subgraph: Optional[set[str]] + conflicted_subgraph: set[str] | None class _NoChainCoverIndex(Exception): @@ -142,7 +141,7 @@ class EventFederationWorkerStore( ): # TODO: this attribute comes from EventPushActionWorkerStore. Should we inherit from # that store so that mypy can deduce this for itself? - stream_ordering_month_ago: Optional[int] + stream_ordering_month_ago: int | None def __init__( self, @@ -494,8 +493,8 @@ class EventFederationWorkerStore( self, room_id: str, state_sets: list[set[str]], - conflicted_set: Optional[set[str]], - additional_backwards_reachable_conflicted_events: Optional[set[str]], + conflicted_set: set[str] | None, + additional_backwards_reachable_conflicted_events: set[str] | None, ) -> StateDifference: """ "Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -556,8 +555,8 @@ class EventFederationWorkerStore( txn: LoggingTransaction, room_id: str, state_sets: list[set[str]], - conflicted_set: Optional[set[str]] = None, - additional_backwards_reachable_conflicted_events: Optional[set[str]] = None, + conflicted_set: set[str] | None = None, + additional_backwards_reachable_conflicted_events: set[str] | None = None, ) -> StateDifference: """Calculates the auth chain difference using the chain index. @@ -1341,7 +1340,7 @@ class EventFederationWorkerStore( async def get_max_depth_of( self, event_ids: Collection[str] - ) -> tuple[Optional[str], int]: + ) -> tuple[str | None, int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs Args: @@ -1373,7 +1372,7 @@ class EventFederationWorkerStore( return max_depth_event_id, current_max_depth - async def get_min_depth_of(self, event_ids: list[str]) -> tuple[Optional[str], int]: + async def get_min_depth_of(self, event_ids: list[str]) -> tuple[str | None, int]: """Returns the event ID and depth for the event that has the min depth from a set of event IDs Args: @@ -1491,7 +1490,7 @@ class EventFederationWorkerStore( ) return frozenset(event_ids) - async def get_min_depth(self, room_id: str) -> Optional[int]: + async def get_min_depth(self, room_id: str) -> int | None: """For the given room, get the minimum depth we have seen for it.""" return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id @@ -1499,7 +1498,7 @@ class EventFederationWorkerStore( def _get_min_depth_interaction( self, txn: LoggingTransaction, room_id: str - ) -> Optional[int]: + ) -> int | None: min_depth = self.db_pool.simple_select_one_onecol_txn( txn, table="room_depth", @@ -1689,7 +1688,7 @@ class EventFederationWorkerStore( ) events = await self.get_events_as_list(event_ids) return sorted( - # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering. + # type-ignore: mypy doesn't like negating the int | None stream_ordering. # But it's never None, because these events were previously persisted to the DB. events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator] @@ -2034,7 +2033,7 @@ class EventFederationWorkerStore( self, origin: str, event_id: str, - ) -> Optional[int]: + ) -> int | None: """Remove the given event from the staging area. Returns: @@ -2043,7 +2042,7 @@ class EventFederationWorkerStore( def _remove_received_event_from_staging_txn( txn: LoggingTransaction, - ) -> Optional[int]: + ) -> int | None: sql = """ DELETE FROM federation_inbound_events_staging WHERE origin = ? AND event_id = ? @@ -2051,7 +2050,7 @@ class EventFederationWorkerStore( """ txn.execute(sql, (origin, event_id)) - row = cast(Optional[tuple[int]], txn.fetchone()) + row = cast(tuple[int] | None, txn.fetchone()) if row is None: return None @@ -2067,7 +2066,7 @@ class EventFederationWorkerStore( async def get_next_staged_event_id_for_room( self, room_id: str, - ) -> Optional[tuple[str, str]]: + ) -> tuple[str, str] | None: """ Get the next event ID in the staging area for the given room. @@ -2077,7 +2076,7 @@ class EventFederationWorkerStore( def _get_next_staged_event_id_for_room_txn( txn: LoggingTransaction, - ) -> Optional[tuple[str, str]]: + ) -> tuple[str, str] | None: sql = """ SELECT origin, event_id FROM federation_inbound_events_staging @@ -2088,7 +2087,7 @@ class EventFederationWorkerStore( txn.execute(sql, (room_id,)) - return cast(Optional[tuple[str, str]], txn.fetchone()) + return cast(tuple[str, str] | None, txn.fetchone()) return await self.db_pool.runInteraction( "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn @@ -2098,12 +2097,12 @@ class EventFederationWorkerStore( self, room_id: str, room_version: RoomVersion, - ) -> Optional[tuple[str, EventBase]]: + ) -> tuple[str, EventBase] | None: """Get the next event in the staging area for the given room.""" def _get_next_staged_event_for_room_txn( txn: LoggingTransaction, - ) -> Optional[tuple[str, str, str]]: + ) -> tuple[str, str, str] | None: sql = """ SELECT event_json, internal_metadata, origin FROM federation_inbound_events_staging @@ -2113,7 +2112,7 @@ class EventFederationWorkerStore( """ txn.execute(sql, (room_id,)) - return cast(Optional[tuple[str, str, str]], txn.fetchone()) + return cast(tuple[str, str, str] | None, txn.fetchone()) row = await self.db_pool.runInteraction( "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn @@ -2258,7 +2257,7 @@ class EventFederationWorkerStore( "SELECT min(received_ts) FROM federation_inbound_events_staging" ) - (received_ts,) = cast(tuple[Optional[int]], txn.fetchone()) + (received_ts,) = cast(tuple[int | None], txn.fetchone()) # If there is nothing in the staging area default it to 0. age = 0 diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index d65ab82fff..2e99d7314e 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -85,8 +85,6 @@ from typing import ( TYPE_CHECKING, Collection, Mapping, - Optional, - Union, cast, ) @@ -115,11 +113,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -DEFAULT_NOTIF_ACTION: list[Union[dict, str]] = [ +DEFAULT_NOTIF_ACTION: list[dict | str] = [ "notify", {"set_tweak": "highlight", "value": False}, ] -DEFAULT_HIGHLIGHT_ACTION: list[Union[dict, str]] = [ +DEFAULT_HIGHLIGHT_ACTION: list[dict | str] = [ "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}, @@ -162,7 +160,7 @@ class HttpPushAction: event_id: str room_id: str stream_ordering: int - actions: list[Union[dict, str]] + actions: list[dict | str] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -172,7 +170,7 @@ class EmailPushAction(HttpPushAction): push notification. """ - received_ts: Optional[int] + received_ts: int | None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -221,9 +219,7 @@ class RoomNotifCounts: _EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {}) -def _serialize_action( - actions: Collection[Union[Mapping, str]], is_highlight: bool -) -> str: +def _serialize_action(actions: Collection[Mapping | str], is_highlight: bool) -> str: """Custom serializer for actions. This allows us to "compress" common actions. We use the fact that most users have the same actions for notifs (and for @@ -241,7 +237,7 @@ def _serialize_action( return json_encoder.encode(actions) -def _deserialize_action(actions: str, is_highlight: bool) -> list[Union[dict, str]]: +def _deserialize_action(actions: str, is_highlight: bool) -> list[dict | str]: """Custom deserializer for actions. This allows us to "compress" common actions""" if actions: return db_to_json(actions) @@ -267,8 +263,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas self._started_ts = self.clock.time_msec() # These get correctly set by _find_stream_orderings_for_times_txn - self.stream_ordering_month_ago: Optional[int] = None - self.stream_ordering_day_ago: Optional[int] = None + self.stream_ordering_month_ago: int | None = None + self.stream_ordering_day_ago: int | None = None cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) @@ -773,8 +769,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas room_id: str, user_id: str, stream_ordering: int, - max_stream_ordering: Optional[int] = None, - thread_id: Optional[str] = None, + max_stream_ordering: int | None = None, + thread_id: str | None = None, ) -> list[tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -1156,7 +1152,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas async def add_push_actions_to_staging( self, event_id: str, - user_id_actions: dict[str, Collection[Union[Mapping, str]]], + user_id_actions: dict[str, Collection[Mapping | str]], count_as_unread: bool, thread_id: str, ) -> None: @@ -1175,7 +1171,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # This is a helper function for generating the necessary tuple that # can be used to insert into the `event_push_actions_staging` table. def _gen_entry( - user_id: str, actions: Collection[Union[Mapping, str]] + user_id: str, actions: Collection[Mapping | str] ) -> tuple[str, str, str, int, int, int, str, int]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 @@ -1293,7 +1289,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The stream ordering """ txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = cast(tuple[Optional[int]], txn.fetchone())[0] + max_stream_ordering = cast(tuple[int | None], txn.fetchone())[0] if max_stream_ordering is None: return 0 @@ -1351,8 +1347,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas async def get_time_of_last_push_action_before( self, stream_ordering: int - ) -> Optional[int]: - def f(txn: LoggingTransaction) -> Optional[tuple[int]]: + ) -> int | None: + def f(txn: LoggingTransaction) -> tuple[int] | None: sql = """ SELECT e.received_ts FROM event_push_actions AS ep @@ -1362,7 +1358,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas LIMIT 1 """ txn.execute(sql, (stream_ordering,)) - return cast(Optional[tuple[int]], txn.fetchone()) + return cast(tuple[int] | None, txn.fetchone()) result = await self.db_pool.runInteraction( "get_time_of_last_push_action_before", f @@ -1454,7 +1450,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas limit, ), ) - rows = cast(list[tuple[int, str, str, Optional[str], int]], txn.fetchall()) + rows = cast(list[tuple[int, str, str, str | None, int]], txn.fetchall()) # For each new read receipt we delete push actions from before it and # recalculate the summary. @@ -1826,7 +1822,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas async def get_push_actions_for_user( self, user_id: str, - before: Optional[int] = None, + before: int | None = None, limit: int = 50, only_highlight: bool = False, ) -> list[UserPushAction]: @@ -1915,7 +1911,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) -def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool: +def _action_has_highlight(actions: Collection[Mapping | str]) -> bool: for action in actions: if not isinstance(action, dict): continue diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index da9ecfbdb9..59112e647c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -29,7 +29,6 @@ from typing import ( Collection, Generator, Iterable, - Optional, Sequence, TypedDict, cast, @@ -141,10 +140,10 @@ class SlidingSyncStateInsertValues(TypedDict, total=False): `sliding_sync_membership_snapshots` database tables. """ - room_type: Optional[str] - is_encrypted: Optional[bool] - room_name: Optional[str] - tombstone_successor_room_id: Optional[str] + room_type: str | None + is_encrypted: bool | None + room_name: str | None + tombstone_successor_room_id: str | None class SlidingSyncMembershipSnapshotSharedInsertValues( @@ -155,7 +154,7 @@ class SlidingSyncMembershipSnapshotSharedInsertValues( multiple memberships """ - has_known_state: Optional[bool] + has_known_state: bool | None @attr.s(slots=True, auto_attribs=True) @@ -193,7 +192,7 @@ class SlidingSyncTableChanges: # foreground update for # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by # https://github.com/element-hq/synapse/issues/17623) - joined_room_bump_stamp_to_fully_insert: Optional[int] + joined_room_bump_stamp_to_fully_insert: int | None # Values to upsert into `sliding_sync_joined_rooms` joined_room_updates: SlidingSyncStateInsertValues @@ -272,8 +271,8 @@ class PersistEventsStore: room_id: str, events_and_contexts: list[EventPersistencePair], *, - state_delta_for_room: Optional[DeltaState], - new_forward_extremities: Optional[set[str]], + state_delta_for_room: DeltaState | None, + new_forward_extremities: set[str] | None, new_event_links: dict[str, NewEventChainLinks], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, @@ -717,7 +716,7 @@ class PersistEventsStore: # `_update_sliding_sync_tables_with_new_persisted_events_txn()`) # joined_room_updates: SlidingSyncStateInsertValues = {} - bump_stamp_to_fully_insert: Optional[int] = None + bump_stamp_to_fully_insert: int | None = None if not delta_state.no_longer_in_room: current_state_ids_map = {} @@ -1014,10 +1013,10 @@ class PersistEventsStore: room_id: str, events_and_contexts: list[EventPersistencePair], inhibit_local_membership_updates: bool, - state_delta_for_room: Optional[DeltaState], - new_forward_extremities: Optional[set[str]], + state_delta_for_room: DeltaState | None, + new_forward_extremities: set[str] | None, new_event_links: dict[str, NewEventChainLinks], - sliding_sync_table_changes: Optional[SlidingSyncTableChanges], + sliding_sync_table_changes: SlidingSyncTableChanges | None, ) -> None: """Insert some number of room events into the necessary database tables. @@ -1570,7 +1569,7 @@ class PersistEventsStore: # existing_chains: set[int] = set() - tree: list[tuple[str, Optional[str]]] = [] + tree: list[tuple[str, str | None]] = [] # We need to do this in a topologically sorted order as we want to # generate chain IDs/sequence numbers of an event's auth events before @@ -1622,7 +1621,7 @@ class PersistEventsStore: if not existing_chain_id: existing_chain_id = chain_map[auth_event_id] - new_chain_tuple: Optional[tuple[Any, int]] = None + new_chain_tuple: tuple[Any, int] | None = None if existing_chain_id: # We found a chain ID/sequence number candidate, check its # not already taken. @@ -2491,7 +2490,7 @@ class PersistEventsStore: room_id: The room ID events_and_contexts: events we are persisting """ - stream_ordering: Optional[int] = None + stream_ordering: int | None = None depth_update = 0 for event, context in events_and_contexts: # Don't update the stream ordering for backfilled events because diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 637b9104c0..f8300e016b 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast import attr @@ -109,7 +109,7 @@ class _JoinedRoomStreamOrderingUpdate: # The most recent event stream_ordering for the room most_recent_event_stream_ordering: int # The most recent event `bump_stamp` for the room - most_recent_bump_stamp: Optional[int] + most_recent_bump_stamp: int | None class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseStore): @@ -1038,7 +1038,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS last_room_id: str, last_depth: int, last_stream: int, - batch_size: Optional[int], + batch_size: int | None, single_room: bool, ) -> _CalculateChainCover: """Calculate the chain cover for `batch_size` events, ordered by @@ -1889,14 +1889,14 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS ) -> list[ tuple[ str, - Optional[str], - Optional[str], + str | None, + str | None, str, str, str, str, int, - Optional[str], + str | None, bool, ] ]: @@ -1982,14 +1982,14 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS list[ tuple[ str, - Optional[str], - Optional[str], + str | None, + str | None, str, str, str, str, int, - Optional[str], + str | None, bool, ] ], @@ -2023,7 +2023,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS def _find_previous_invite_or_knock_membership_txn( txn: LoggingTransaction, room_id: str, user_id: str, event_id: str - ) -> Optional[tuple[str, str]]: + ) -> tuple[str, str] | None: # Find the previous invite/knock event before the leave event # # Here are some notes on how we landed on this query: @@ -2598,7 +2598,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS # Find the next room ID to process, with a relevant room version. room_ids: list[str] = [] - max_room_id: Optional[str] = None + max_room_id: str | None = None for room_id, room_version_str in txn: max_room_id = room_id @@ -2704,7 +2704,7 @@ def _resolve_stale_data_in_sliding_sync_joined_rooms_table( # If we have nothing written to the `sliding_sync_joined_rooms` table, there is # nothing to clean up - row = cast(Optional[tuple[int]], txn.fetchone()) + row = cast(tuple[int] | None, txn.fetchone()) max_stream_ordering_sliding_sync_joined_rooms_table = None depends_on = None if row is not None: @@ -2830,7 +2830,7 @@ def _resolve_stale_data_in_sliding_sync_membership_snapshots_table( # If we have nothing written to the `sliding_sync_membership_snapshots` table, # there is nothing to clean up - row = cast(Optional[tuple[int]], txn.fetchone()) + row = cast(tuple[int] | None, txn.fetchone()) max_stream_ordering_sliding_sync_membership_snapshots_table = None if row is not None: (max_stream_ordering_sliding_sync_membership_snapshots_table,) = row diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index d43fb443fd..9908244dbf 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -20,7 +20,7 @@ # import logging -from typing import Optional, cast +from typing import cast from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction @@ -98,7 +98,7 @@ class EventForwardExtremitiesStore( async def get_forward_extremities_for_room( self, room_id: str - ) -> list[tuple[str, int, int, Optional[int]]]: + ) -> list[tuple[str, int, int, int | None]]: """ Get list of forward extremities for a room. @@ -108,7 +108,7 @@ class EventForwardExtremitiesStore( def get_forward_extremities_for_room_txn( txn: LoggingTransaction, - ) -> list[tuple[str, int, int, Optional[int]]]: + ) -> list[tuple[str, int, int, int | None]]: sql = """ SELECT event_id, state_group, depth, received_ts FROM event_forward_extremities @@ -118,7 +118,7 @@ class EventForwardExtremitiesStore( """ txn.execute(sql, (room_id,)) - return cast(list[tuple[str, int, int, Optional[int]]], txn.fetchall()) + return cast(list[tuple[str, int, int, int | None]], txn.fetchall()) return await self.db_pool.runInteraction( "get_forward_extremities_for_room", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 005f75a2d8..29bc1b982a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -31,7 +31,6 @@ from typing import ( Literal, Mapping, MutableMapping, - Optional, cast, overload, ) @@ -146,7 +145,7 @@ class InvalidEventError(Exception): @attr.s(slots=True, auto_attribs=True) class EventCacheEntry: event: EventBase - redacted_event: Optional[EventBase] + redacted_event: EventBase | None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -184,9 +183,9 @@ class _EventRow: instance_name: str json: str internal_metadata: str - format_version: Optional[int] - room_version_id: Optional[str] - rejected_reason: Optional[str] + format_version: int | None + room_version_id: str | None + rejected_reason: str | None redactions: list[str] outlier: bool @@ -501,7 +500,7 @@ class EventsWorkerStore(SQLBaseStore): get_prev_content: bool = ..., allow_rejected: bool = ..., allow_none: Literal[False] = ..., - check_room_id: Optional[str] = ..., + check_room_id: str | None = ..., ) -> EventBase: ... @overload @@ -512,8 +511,8 @@ class EventsWorkerStore(SQLBaseStore): get_prev_content: bool = ..., allow_rejected: bool = ..., allow_none: Literal[True] = ..., - check_room_id: Optional[str] = ..., - ) -> Optional[EventBase]: ... + check_room_id: str | None = ..., + ) -> EventBase | None: ... @cancellable async def get_event( @@ -523,8 +522,8 @@ class EventsWorkerStore(SQLBaseStore): get_prev_content: bool = False, allow_rejected: bool = False, allow_none: bool = False, - check_room_id: Optional[str] = None, - ) -> Optional[EventBase]: + check_room_id: str | None = None, + ) -> EventBase | None: """Get an event from the database by event_id. Events for unknown room versions will also be filtered out. @@ -1090,7 +1089,7 @@ class EventsWorkerStore(SQLBaseStore): self, context: EventContext, state_keys_to_include: StateFilter, - membership_user_id: Optional[str] = None, + membership_user_id: str | None = None, ) -> list[JsonDict]: """ Retrieve the stripped state from a room, given an event context to retrieve state @@ -1403,7 +1402,7 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row.room_version_id - room_version: Optional[RoomVersion] + room_version: RoomVersion | None if not room_version_id: # this should only happen for out-of-band membership events which # arrived before https://github.com/matrix-org/synapse/issues/6983 @@ -1653,7 +1652,7 @@ class EventsWorkerStore(SQLBaseStore): original_ev: EventBase, redactions: Iterable[str], event_map: dict[str, EventBase], - ) -> Optional[EventBase]: + ) -> EventBase | None: """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. @@ -2131,7 +2130,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_senders_for_event_ids( self, event_ids: Collection[str] - ) -> dict[str, Optional[str]]: + ) -> dict[str, str | None]: """ Given a sequence of event IDs, return the sender associated with each. @@ -2147,7 +2146,7 @@ class EventsWorkerStore(SQLBaseStore): def _get_senders_for_event_ids( txn: LoggingTransaction, - ) -> dict[str, Optional[str]]: + ) -> dict[str, str | None]: rows = self.db_pool.simple_select_many_txn( txn=txn, table="events", @@ -2178,7 +2177,7 @@ class EventsWorkerStore(SQLBaseStore): return int(res[0]), int(res[1]) - async def get_next_event_to_expire(self) -> Optional[tuple[str, int]]: + async def get_next_event_to_expire(self) -> tuple[str, int] | None: """Retrieve the entry with the lowest expiry timestamp in the event_expiry table, or None if there's no more event to expire. @@ -2190,7 +2189,7 @@ class EventsWorkerStore(SQLBaseStore): def get_next_event_to_expire_txn( txn: LoggingTransaction, - ) -> Optional[tuple[str, int]]: + ) -> tuple[str, int] | None: txn.execute( """ SELECT event_id, expiry_ts FROM event_expiry @@ -2198,7 +2197,7 @@ class EventsWorkerStore(SQLBaseStore): """ ) - return cast(Optional[tuple[str, int]], txn.fetchone()) + return cast(tuple[str, int] | None, txn.fetchone()) return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn @@ -2206,7 +2205,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_event_id_from_transaction_id_and_device_id( self, room_id: str, user_id: str, device_id: str, txn_id: str - ) -> Optional[str]: + ) -> str | None: """Look up if we have already persisted an event for the transaction ID, returning the event ID if so. """ @@ -2427,7 +2426,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_event_id_for_timestamp( self, room_id: str, timestamp: int, direction: Direction - ) -> Optional[str]: + ) -> str | None: """Find the closest event to the given timestamp in the given direction. Args: @@ -2481,7 +2480,7 @@ class EventsWorkerStore(SQLBaseStore): LIMIT 1; """ - def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: + def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> str | None: txn.execute( sql_template, (room_id, timestamp), @@ -2591,7 +2590,7 @@ class EventsWorkerStore(SQLBaseStore): self, txn: LoggingTransaction, event_id: str, - rejection_reason: Optional[str], + rejection_reason: str | None, ) -> None: """Mark an event that was previously accepted as rejected, or vice versa @@ -2640,8 +2639,8 @@ class EventsWorkerStore(SQLBaseStore): self.invalidate_get_event_cache_after_txn(txn, event_id) async def get_events_sent_by_user_in_room( - self, user_id: str, room_id: str, limit: int, filter: Optional[list[str]] = None - ) -> Optional[list[str]]: + self, user_id: str, room_id: str, limit: int, filter: list[str] | None = None + ) -> list[str] | None: """ Get a list of event ids of events sent by the user in the specified room @@ -2656,10 +2655,10 @@ class EventsWorkerStore(SQLBaseStore): txn: LoggingTransaction, user_id: str, room_id: str, - filter: Optional[list[str]], + filter: list[str] | None, batch_size: int, offset: int, - ) -> tuple[Optional[list[str]], int]: + ) -> tuple[list[str] | None, int]: if filter: base_clause, args = make_in_list_sql_clause( txn.database_engine, "type", filter @@ -2767,7 +2766,7 @@ class EventsWorkerStore(SQLBaseStore): @cached(tree=True) async def get_metadata_for_event( self, room_id: str, event_id: str - ) -> Optional[EventMetadata]: + ) -> EventMetadata | None: row = await self.db_pool.simple_select_one( table="events", keyvalues={"room_id": room_id, "event_id": event_id}, diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 4b3bc69d20..2019ad9904 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -20,7 +20,7 @@ # # -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, cast from canonicaljson import encode_canonical_json @@ -72,7 +72,7 @@ class FilteringWorkerStore(SQLBaseStore): lower_bound_id = progress.get("lower_bound_id", "") - def _get_last_id(txn: LoggingTransaction) -> Optional[str]: + def _get_last_id(txn: LoggingTransaction) -> str | None: sql = """ SELECT user_id FROM user_filters WHERE user_id > ? @@ -151,7 +151,7 @@ class FilteringWorkerStore(SQLBaseStore): @cached(num_args=2) async def get_user_filter( - self, user_id: UserID, filter_id: Union[int, str] + self, user_id: UserID, filter_id: int | str ) -> JsonMapping: # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. @@ -187,7 +187,7 @@ class FilteringWorkerStore(SQLBaseStore): sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?" txn.execute(sql, (user_id.to_string(),)) - max_id = cast(tuple[Optional[int]], txn.fetchone())[0] + max_id = cast(tuple[int | None], txn.fetchone())[0] if max_id is None: filter_id = 0 else: diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 9833565095..f81257b5a1 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -22,7 +22,7 @@ import itertools import json import logging -from typing import Iterable, Mapping, Optional, Union, cast +from typing import Iterable, Mapping, cast from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -201,7 +201,7 @@ class KeyStore(CacheInvalidationWorkerStore): self, server_name: str, key_id: str, - ) -> Optional[FetchKeyResultForRemote]: + ) -> FetchKeyResultForRemote | None: raise NotImplementedError() @cachedList( @@ -209,13 +209,13 @@ class KeyStore(CacheInvalidationWorkerStore): ) async def get_server_keys_json_for_remote( self, server_name: str, key_ids: Iterable[str] - ) -> Mapping[str, Optional[FetchKeyResultForRemote]]: + ) -> Mapping[str, FetchKeyResultForRemote | None]: """Fetch the cached keys for the given server/key IDs. If we have multiple entries for a given key ID, returns the most recent. """ rows = cast( - list[tuple[str, str, int, int, Union[bytes, memoryview]]], + list[tuple[str, str, int, int, bytes | memoryview]], await self.db_pool.simple_select_many_batch( table="server_keys_json", column="key_id", @@ -258,7 +258,7 @@ class KeyStore(CacheInvalidationWorkerStore): If we have multiple entries for a given key ID, returns the most recent. """ rows = cast( - list[tuple[str, str, int, int, Union[bytes, memoryview]]], + list[tuple[str, str, int, int, bytes | memoryview]], await self.db_pool.simple_select_list( table="server_keys_json", keyvalues={"server_name": server_name}, diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 9dd2cae344..51f04acbcb 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -290,7 +290,7 @@ class LockStore(SQLBaseStore): self, lock_names: Collection[tuple[str, str]], write: bool, - ) -> Optional[AsyncExitStack]: + ) -> AsyncExitStack | None: """Try to acquire multiple locks for the given names/keys. Will return an async context manager if the locks are successfully acquired, which *must* be used (otherwise the lock will leak). @@ -402,7 +402,7 @@ class Lock: # We might be called from a non-main thread, so we defer setting up the # looping call. - self._looping_call: Optional[LoopingCall] = None + self._looping_call: LoopingCall | None = None reactor.callFromThread(self._setup_looping_call) self._dropped = False @@ -497,9 +497,9 @@ class Lock: async def __aexit__( self, - _exctype: Optional[type[BaseException]], - _excinst: Optional[BaseException], - _exctb: Optional[TracebackType], + _exctype: type[BaseException] | None, + _excinst: BaseException | None, + _exctb: TracebackType | None, ) -> bool: await self.release() diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index b9f882662e..50664d63e5 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -25,8 +25,6 @@ from typing import ( TYPE_CHECKING, Collection, Iterable, - Optional, - Union, cast, ) @@ -57,16 +55,16 @@ logger = logging.getLogger(__name__) class LocalMedia: media_id: str media_type: str - media_length: Optional[int] + media_length: int | None upload_name: str created_ts: int - url_cache: Optional[str] + url_cache: str | None last_access_ts: int - quarantined_by: Optional[str] + quarantined_by: str | None safe_from_quarantine: bool - user_id: Optional[str] - authenticated: Optional[bool] - sha256: Optional[str] + user_id: str | None + authenticated: bool | None + sha256: str | None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -75,20 +73,20 @@ class RemoteMedia: media_id: str media_type: str media_length: int - upload_name: Optional[str] + upload_name: str | None filesystem_id: str created_ts: int last_access_ts: int - quarantined_by: Optional[str] - authenticated: Optional[bool] - sha256: Optional[str] + quarantined_by: str | None + authenticated: bool | None + sha256: str | None @attr.s(slots=True, frozen=True, auto_attribs=True) class UrlCache: response_code: int expires_ts: int - og: Union[str, bytes] + og: str | bytes class MediaSortOrder(Enum): @@ -183,7 +181,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): ) if hs.config.media.can_load_media_repo: - self.unused_expiration_time: Optional[int] = ( + self.unused_expiration_time: int | None = ( hs.config.media.unused_expiration_time ) else: @@ -224,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): super().__init__(database, db_conn, hs) self.server_name: str = hs.hostname - async def get_local_media(self, media_id: str) -> Optional[LocalMedia]: + async def get_local_media(self, media_id: str) -> LocalMedia | None: """Get the metadata for a local piece of media Returns: @@ -299,7 +297,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): else: order = "ASC" - args: list[Union[str, int]] = [user_id] + args: list[str | int] = [user_id] sql = """ SELECT COUNT(*) as total_media FROM local_media_repository @@ -472,12 +470,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_id: str, media_type: str, time_now_ms: int, - upload_name: Optional[str], + upload_name: str | None, media_length: int, user_id: UserID, - url_cache: Optional[str] = None, - sha256: Optional[str] = None, - quarantined_by: Optional[str] = None, + url_cache: str | None = None, + sha256: str | None = None, + quarantined_by: str | None = None, ) -> None: if self.hs.config.media.enable_authenticated_media: authenticated = True @@ -505,12 +503,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): self, media_id: str, media_type: str, - upload_name: Optional[str], + upload_name: str | None, media_length: int, user_id: UserID, sha256: str, - url_cache: Optional[str] = None, - quarantined_by: Optional[str] = None, + url_cache: str | None = None, + quarantined_by: str | None = None, ) -> None: updatevalues = { "media_type": media_type, @@ -575,13 +573,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_pending_media", get_pending_media_txn ) - async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]: + async def get_url_cache(self, url: str, ts: int) -> UrlCache | None: """Get the media_id and ts for a cached URL as of the given timestamp Returns: None if the URL isn't cached. """ - def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]: + def get_url_cache_txn(txn: LoggingTransaction) -> UrlCache | None: # get the most recently cached result (relative to the given ts) sql = """ SELECT response_code, expires_ts, og @@ -615,7 +613,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): self, url: str, response_code: int, - etag: Optional[str], + etag: str | None, expires_ts: int, og: str, media_id: str, @@ -683,7 +681,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_cached_remote_media( self, origin: str, media_id: str - ) -> Optional[RemoteMedia]: + ) -> RemoteMedia | None: row = await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, @@ -724,9 +722,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_type: str, media_length: int, time_now_ms: int, - upload_name: Optional[str], + upload_name: str | None, filesystem_id: str, - sha256: Optional[str], + sha256: str | None, ) -> None: if self.hs.config.media.enable_authenticated_media: authenticated = True @@ -822,7 +820,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): t_width: int, t_height: int, t_type: str, - ) -> Optional[ThumbnailInfo]: + ) -> ThumbnailInfo | None: """Fetch the thumbnail info of given width, height and type.""" row = await self.db_pool.simple_select_one( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index bf8e540ffb..b51edd5d0c 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -18,7 +18,7 @@ # # import logging -from typing import TYPE_CHECKING, Mapping, Optional, cast +from typing import TYPE_CHECKING, Mapping, cast from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -129,7 +129,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): ) async def get_monthly_active_users_by_service( - self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None + self, start_timestamp: int | None = None, end_timestamp: int | None = None ) -> list[tuple[str, str]]: """Generates list of monthly active users and their services. Please see "get_monthly_active_count_by_service" docstring for more details @@ -194,7 +194,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): return users @cached(num_args=1) - async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]: + async def user_last_seen_monthly_active(self, user_id: str) -> int | None: """ Checks if a given user is part of the monthly active user group diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index 0db7f73730..15c47a2562 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -19,7 +19,6 @@ # # -from typing import Optional from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction @@ -41,8 +40,8 @@ class OpenIdStore(SQLBaseStore): async def get_user_id_for_open_id_token( self, token: str, ts_now_ms: int - ) -> Optional[str]: - def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]: + ) -> str | None: + def get_user_id_for_token_txn(txn: LoggingTransaction) -> str | None: sql = ( "SELECT user_id FROM open_id_tokens" " WHERE token = ? AND ? <= ts_valid_until_ms" diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index fec94f4e5a..75ca9e40d7 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,8 +23,6 @@ from typing import ( Any, Iterable, Mapping, - Optional, - Union, cast, ) @@ -260,7 +258,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) # TODO All these columns are nullable, but we don't expect that: # https://github.com/matrix-org/synapse/issues/16467 rows = cast( - list[tuple[str, str, int, int, int, Optional[str], Union[int, bool]]], + list[tuple[str, str, int, int, int, str | None, int | bool]], await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", @@ -317,7 +315,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) @cached() async def _get_full_presence_stream_token_for_user( self, user_id: str - ) -> Optional[int]: + ) -> int | None: """Get the presence token corresponding to the last full presence update for this user. @@ -399,7 +397,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) # TODO All these columns are nullable, but we don't expect that: # https://github.com/matrix-org/synapse/issues/16467 rows = cast( - list[tuple[str, str, int, int, int, Optional[str], Union[int, bool]]], + list[tuple[str, str, int, int, int, str | None, int | bool]], await self.db_pool.runInteraction( "get_presence_for_all_users", self.db_pool.simple_select_list_paginate_txn, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 71f01a597b..11ad516eb3 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -19,7 +19,7 @@ # # import json -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast from canonicaljson import encode_canonical_json @@ -75,7 +75,7 @@ class ProfileWorkerStore(SQLBaseStore): lower_bound_id = progress.get("lower_bound_id", "") - def _get_last_id(txn: LoggingTransaction) -> Optional[str]: + def _get_last_id(txn: LoggingTransaction) -> str | None: sql = """ SELECT user_id FROM profiles WHERE user_id > ? @@ -176,7 +176,7 @@ class ProfileWorkerStore(SQLBaseStore): return ProfileInfo(avatar_url=profile[1], display_name=profile[0]) - async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: + async def get_profile_displayname(self, user_id: UserID) -> str | None: """ Fetch the display name of a user. @@ -193,7 +193,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_displayname", ) - async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]: + async def get_profile_avatar_url(self, user_id: UserID) -> str | None: """ Fetch the avatar URL of a user. @@ -257,9 +257,7 @@ class ProfileWorkerStore(SQLBaseStore): ) # If value_type is None, then the value did not exist. - value_type, value = cast( - tuple[Optional[str], JsonValue], txn.fetchone() - ) + value_type, value = cast(tuple[str | None, JsonValue], txn.fetchone()) if not value_type: raise StoreError(404, "No row found") # If value_type is object or array, then need to deserialize the JSON. @@ -346,7 +344,7 @@ class ProfileWorkerStore(SQLBaseStore): # possible due to the grammar. (f'$."{new_field_name}"', user_id.localpart), ) - row = cast(tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone()) + row = cast(tuple[int | None, int | None, int | None], txn.fetchone()) # The values return null if the column is null. total_bytes = ( @@ -373,7 +371,7 @@ class ProfileWorkerStore(SQLBaseStore): raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE) async def set_profile_displayname( - self, user_id: UserID, new_displayname: Optional[str] + self, user_id: UserID, new_displayname: str | None ) -> None: """ Set the display name of a user. @@ -406,7 +404,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_avatar_url( - self, user_id: UserID, new_avatar_url: Optional[str] + self, user_id: UserID, new_avatar_url: str | None ) -> None: """ Set the avatar of a user. diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index ecab19eb2e..d361166cec 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -25,9 +25,7 @@ from typing import ( Collection, Iterable, Mapping, - Optional, Sequence, - Union, cast, ) @@ -231,7 +229,7 @@ class PushRulesWorkerStore( async def get_push_rules_enabled_for_user(self, user_id: str) -> dict[str, bool]: results = cast( - list[tuple[str, Optional[Union[int, bool]]]], + list[tuple[str, int | bool | None]], await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, @@ -327,7 +325,7 @@ class PushRulesWorkerStore( results: dict[str, dict[str, bool]] = {user_id: {} for user_id in user_ids} rows = cast( - list[tuple[str, str, Optional[int]]], + list[tuple[str, str, int | None]], await self.db_pool.simple_select_many_batch( table="push_rules_enable", column="user_name", @@ -402,9 +400,9 @@ class PushRulesWorkerStore( rule_id: str, priority_class: int, conditions: Sequence[Mapping[str, str]], - actions: Sequence[Union[Mapping[str, Any], str]], - before: Optional[str] = None, - after: Optional[str] = None, + actions: Sequence[Mapping[str, Any] | str], + before: str | None = None, + after: str | None = None, ) -> None: if not self._is_push_writer: raise Exception("Not a push writer") @@ -791,7 +789,7 @@ class PushRulesWorkerStore( self, user_id: str, rule_id: str, - actions: list[Union[dict, str]], + actions: list[dict | str], is_default_rule: bool, ) -> None: """ @@ -882,7 +880,7 @@ class PushRulesWorkerStore( user_id: str, rule_id: str, op: str, - data: Optional[JsonDict] = None, + data: JsonDict | None = None, ) -> None: if not self._is_push_writer: raise Exception("Not a push writer") diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c8f049536a..e7ab7f64f9 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -25,7 +25,6 @@ from typing import ( Any, Iterable, Iterator, - Optional, cast, ) @@ -51,7 +50,7 @@ logger = logging.getLogger(__name__) PusherRow = tuple[ int, # id str, # user_name - Optional[int], # access_token + int | None, # access_token str, # profile_tag str, # kind str, # app_id @@ -365,7 +364,7 @@ class PusherWorkerStore(SQLBaseStore): return bool(updated) async def update_pusher_failing_since( - self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int] + self, app_id: str, pushkey: str, user_id: str, failing_since: int | None ) -> None: await self.db_pool.simple_update( table="pushers", @@ -378,7 +377,7 @@ class PusherWorkerStore(SQLBaseStore): self, pusher_id: int ) -> dict[str, ThrottleParams]: res = cast( - list[tuple[str, Optional[int], Optional[int]]], + list[tuple[str, int | None, int | None]], await self.db_pool.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, @@ -607,7 +606,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): (last_pusher_id, batch_size), ) - rows = cast(list[tuple[int, Optional[str], Optional[str]]], txn.fetchall()) + rows = cast(list[tuple[int, str | None, str | None]], txn.fetchall()) if len(rows) == 0: return 0 @@ -666,13 +665,13 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): device_display_name: str, pushkey: str, pushkey_ts: int, - lang: Optional[str], - data: Optional[JsonDict], + lang: str | None, + data: JsonDict | None, last_stream_ordering: int, profile_tag: str = "", enabled: bool = True, - device_id: Optional[str] = None, - access_token_id: Optional[int] = None, + device_id: str | None = None, + access_token_id: int | None = None, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: await self.db_pool.simple_upsert( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 63d4e1f68c..ba5e07a051 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -26,7 +26,6 @@ from typing import ( Collection, Iterable, Mapping, - Optional, Sequence, cast, ) @@ -67,7 +66,7 @@ class ReceiptInRoom: receipt_type: str user_id: str event_id: str - thread_id: Optional[str] + thread_id: str | None data: JsonMapping @staticmethod @@ -176,7 +175,7 @@ class ReceiptsWorkerStore(SQLBaseStore): user_id: str, room_id: str, receipt_types: Collection[str], - ) -> Optional[tuple[str, int]]: + ) -> tuple[str, int] | None: """ Fetch the event ID and stream_ordering for the latest unthreaded receipt in a room with one of the given receipt types. @@ -208,7 +207,7 @@ class ReceiptsWorkerStore(SQLBaseStore): args.extend((user_id, room_id)) txn.execute(sql, args) - return cast(Optional[tuple[str, int]], txn.fetchone()) + return cast(tuple[str, int] | None, txn.fetchone()) async def get_receipts_for_user( self, user_id: str, receipt_types: Iterable[str] @@ -311,7 +310,7 @@ class ReceiptsWorkerStore(SQLBaseStore): self, room_ids: Iterable[str], to_key: MultiWriterStreamToken, - from_key: Optional[MultiWriterStreamToken] = None, + from_key: MultiWriterStreamToken | None = None, ) -> list[JsonMapping]: """Get receipts for multiple rooms for sending to clients. @@ -343,7 +342,7 @@ class ReceiptsWorkerStore(SQLBaseStore): self, room_id: str, to_key: MultiWriterStreamToken, - from_key: Optional[MultiWriterStreamToken] = None, + from_key: MultiWriterStreamToken | None = None, ) -> Sequence[JsonMapping]: """Get receipts for a single room for sending to clients. @@ -371,7 +370,7 @@ class ReceiptsWorkerStore(SQLBaseStore): self, room_id: str, to_key: MultiWriterStreamToken, - from_key: Optional[MultiWriterStreamToken] = None, + from_key: MultiWriterStreamToken | None = None, ) -> Sequence[JsonMapping]: """See get_linearized_receipts_for_room""" @@ -425,7 +424,7 @@ class ReceiptsWorkerStore(SQLBaseStore): self, room_ids: Collection[str], to_key: MultiWriterStreamToken, - from_key: Optional[MultiWriterStreamToken] = None, + from_key: MultiWriterStreamToken | None = None, ) -> Mapping[str, Sequence[JsonMapping]]: if not room_ids: return {} @@ -528,7 +527,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_linearized_receipts_for_events_txn( txn: LoggingTransaction, room_id_event_id_tuples: Collection[tuple[str, str]], - ) -> list[tuple[str, str, str, str, Optional[str], str]]: + ) -> list[tuple[str, str, str, str, str | None, str]]: clause, args = make_tuple_in_list_sql_clause( self.database_engine, ("room_id", "event_id"), room_id_event_id_tuples ) @@ -578,7 +577,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_linearized_receipts_for_all_rooms( self, to_key: MultiWriterStreamToken, - from_key: Optional[MultiWriterStreamToken] = None, + from_key: MultiWriterStreamToken | None = None, ) -> Mapping[str, JsonMapping]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. @@ -655,7 +654,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_linearized_receipts_for_user_in_rooms_txn( txn: LoggingTransaction, batch_room_ids: StrCollection, - ) -> list[tuple[str, str, str, str, Optional[str], str]]: + ) -> list[tuple[str, str, str, str, str | None, str]]: clause, args = make_in_list_sql_clause( self.database_engine, "room_id", batch_room_ids ) @@ -780,7 +779,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_all_updated_receipts( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> tuple[ - list[tuple[int, tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool + list[tuple[int, tuple[str, str, str, str, str | None, JsonDict]]], int, bool ]: """Get updates for receipts replication stream. @@ -809,7 +808,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_all_updated_receipts_txn( txn: LoggingTransaction, ) -> tuple[ - list[tuple[int, tuple[str, str, str, str, Optional[str], JsonDict]]], + list[tuple[int, tuple[str, str, str, str, str | None, JsonDict]]], int, bool, ]: @@ -824,7 +823,7 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, instance_name, limit)) updates = cast( - list[tuple[int, tuple[str, str, str, str, Optional[str], JsonDict]]], + list[tuple[int, tuple[str, str, str, str, str | None, JsonDict]]], [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn], ) @@ -884,10 +883,10 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_id: str, - thread_id: Optional[str], + thread_id: str | None, data: JsonDict, stream_id: int, - ) -> Optional[int]: + ) -> int | None: """Inserts a receipt into the database if it's newer than the current one. Returns: @@ -1023,9 +1022,9 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: list[str], - thread_id: Optional[str], + thread_id: str | None, data: dict, - ) -> Optional[PersistedPosition]: + ) -> PersistedPosition | None: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph @@ -1095,7 +1094,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: list[str], - thread_id: Optional[str], + thread_id: str | None, data: JsonDict, ) -> None: assert self._can_write_to_receipts diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index bad2d0b63a..545b0f11c4 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -22,7 +22,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast import attr @@ -94,8 +94,8 @@ class TokenLookupResult: token_id: int is_guest: bool = False shadow_banned: bool = False - device_id: Optional[str] = None - valid_until_ms: Optional[int] = None + device_id: str | None = None + valid_until_ms: int | None = None token_owner: str = attr.ib() token_used: bool = False @@ -118,7 +118,7 @@ class RefreshTokenLookupResult: token_id: int """The ID of this refresh token.""" - next_token_id: Optional[int] + next_token_id: int | None """The ID of the refresh token which replaced this one.""" has_next_refresh_token_been_refreshed: bool @@ -127,11 +127,11 @@ class RefreshTokenLookupResult: has_next_access_token_been_used: bool """True if the next access token was already used at least once.""" - expiry_ts: Optional[int] + expiry_ts: int | None """The time at which the refresh token expires and can not be used. If None, the refresh token doesn't expire.""" - ultimate_session_expiry_ts: Optional[int] + ultimate_session_expiry_ts: int | None """The time at which the session comes to an end and can no longer be refreshed. If None, the session can be refreshed indefinitely.""" @@ -144,10 +144,10 @@ class LoginTokenLookupResult: user_id: str """The user this token belongs to.""" - auth_provider_id: Optional[str] + auth_provider_id: str | None """The SSO Identity Provider that the user authenticated with, to get this token.""" - auth_provider_session_id: Optional[str] + auth_provider_session_id: str | None """The session ID advertised by the SSO Identity Provider.""" @@ -171,7 +171,7 @@ class ThreepidValidationSession: """ID of the validation session""" last_send_attempt: int """a number serving to dedupe send attempts for this session""" - validated_at: Optional[int] + validated_at: int | None """timestamp of when this session was validated if so""" @@ -233,13 +233,13 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): async def register_user( self, user_id: str, - password_hash: Optional[str] = None, + password_hash: str | None = None, was_guest: bool = False, make_guest: bool = False, - appservice_id: Optional[str] = None, - create_profile_with_displayname: Optional[str] = None, + appservice_id: str | None = None, + create_profile_with_displayname: str | None = None, admin: bool = False, - user_type: Optional[str] = None, + user_type: str | None = None, shadow_banned: bool = False, approved: bool = False, ) -> None: @@ -286,13 +286,13 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): self, txn: LoggingTransaction, user_id: str, - password_hash: Optional[str], + password_hash: str | None, was_guest: bool, make_guest: bool, - appservice_id: Optional[str], - create_profile_with_displayname: Optional[str], + appservice_id: str | None, + create_profile_with_displayname: str | None, admin: bool, - user_type: Optional[str], + user_type: str | None, shadow_banned: bool, approved: bool, ) -> None: @@ -379,10 +379,10 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @cached() - async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]: + async def get_user_by_id(self, user_id: str) -> UserInfo | None: """Returns info about the user account, if it exists.""" - def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]: + def get_user_by_id_txn(txn: LoggingTransaction) -> UserInfo | None: # We could technically use simple_select_one here, but it would not perform # the COALESCEs (unless hacked into the column names), which could yield # confusing results. @@ -466,7 +466,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): return is_trial @cached() - async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]: + async def get_user_by_access_token(self, token: str) -> TokenLookupResult | None: """Get a user from the given access token. Args: @@ -479,7 +479,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): ) @cached() - async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]: + async def get_expiration_ts_for_user(self, user_id: str) -> int | None: """Get the expiration timestamp for the account bearing a given user ID. Args: @@ -515,8 +515,8 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): user_id: str, expiration_ts: int, email_sent: bool, - renewal_token: Optional[str] = None, - token_used_ts: Optional[int] = None, + renewal_token: str | None = None, + token_used_ts: int | None = None, ) -> None: """Updates the account validity properties of the given account, with the given values. @@ -576,7 +576,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): async def get_user_from_renewal_token( self, renewal_token: str - ) -> tuple[str, int, Optional[int]]: + ) -> tuple[str, int, int | None]: """Get a user ID and renewal status from a renewal token. Args: @@ -592,7 +592,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): has not been renewed using the current token yet. """ return cast( - tuple[str, int, Optional[int]], + tuple[str, int, int | None], await self.db_pool.simple_select_one( table="account_validity", keyvalues={"renewal_token": renewal_token}, @@ -745,7 +745,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) async def set_user_type( - self, user: UserID, user_type: Optional[Union[UserTypes, str]] + self, user: UserID, user_type: UserTypes | str | None ) -> None: """Sets the user type. @@ -766,7 +766,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): def _query_for_auth( self, txn: LoggingTransaction, token: str - ) -> Optional[TokenLookupResult]: + ) -> TokenLookupResult | None: sql = """ SELECT users.name as user_id, users.is_guest, @@ -1027,7 +1027,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): @cached() async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> Optional[str]: + ) -> str | None: """Look up a user by their external auth id Args: @@ -1145,7 +1145,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): return str(next_id) - async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: + async def get_user_id_by_threepid(self, medium: str, address: str) -> str | None: """Returns user id from threepid Args: @@ -1163,7 +1163,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): def get_user_id_by_threepid_txn( self, txn: LoggingTransaction, medium: str, address: str - ) -> Optional[str]: + ) -> str | None: """Returns user id from threepid Args: @@ -1386,12 +1386,12 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): async def get_threepid_validation_session( self, - medium: Optional[str], + medium: str | None, client_secret: str, - address: Optional[str] = None, - sid: Optional[str] = None, - validated: Optional[bool] = True, - ) -> Optional[ThreepidValidationSession]: + address: str | None = None, + sid: str | None = None, + validated: bool | None = True, + ) -> ThreepidValidationSession | None: """Gets a session_id and last_send_attempt (if available) for a combination of validation metadata @@ -1425,7 +1425,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): def get_threepid_validation_session_txn( txn: LoggingTransaction, - ) -> Optional[ThreepidValidationSession]: + ) -> ThreepidValidationSession | None: sql = """ SELECT address, session_id, medium, client_secret, last_send_attempt, validated_at @@ -1555,7 +1555,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): values={"expiration_ts_ms": expiration_ts, "email_sent": False}, ) - async def get_user_pending_deactivation(self) -> Optional[str]: + async def get_user_pending_deactivation(self) -> str | None: """ Gets one user from the table of users waiting to be parted from all the rooms they're in. @@ -1686,7 +1686,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): """ def _use_registration_token_txn(txn: LoggingTransaction) -> None: - # Normally, res is Optional[dict[str, Any]]. + # Normally, res is dict[str, Any] | None. # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. @@ -1715,8 +1715,8 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): ) async def get_registration_tokens( - self, valid: Optional[bool] = None - ) -> list[tuple[str, Optional[int], int, int, Optional[int]]]: + self, valid: bool | None = None + ) -> list[tuple[str, int | None, int, int, int | None]]: """List all registration tokens. Used by the admin API. Args: @@ -1734,8 +1734,8 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): """ def select_registration_tokens_txn( - txn: LoggingTransaction, now: int, valid: Optional[bool] - ) -> list[tuple[str, Optional[int], int, int, Optional[int]]]: + txn: LoggingTransaction, now: int, valid: bool | None + ) -> list[tuple[str, int | None, int, int, int | None]]: if valid is None: # Return all tokens regardless of validity txn.execute( @@ -1765,7 +1765,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): txn.execute(sql, [now]) return cast( - list[tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall() + list[tuple[str, int | None, int, int, int | None]], txn.fetchall() ) return await self.db_pool.runInteraction( @@ -1775,7 +1775,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): valid, ) - async def get_one_registration_token(self, token: str) -> Optional[dict[str, Any]]: + async def get_one_registration_token(self, token: str) -> dict[str, Any] | None: """Get info about the given registration token. Used by the admin API. Args: @@ -1801,9 +1801,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): "expiry_time": row[4], } - async def generate_registration_token( - self, length: int, chars: str - ) -> Optional[str]: + async def generate_registration_token(self, length: int, chars: str) -> str | None: """Generate a random registration token. Used by the admin API. Args: @@ -1843,7 +1841,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): ) async def create_registration_token( - self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int] + self, token: str, uses_allowed: int | None, expiry_time: int | None ) -> bool: """Create a new registration token. Used by the admin API. @@ -1892,8 +1890,8 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): ) async def update_registration_token( - self, token: str, updatevalues: dict[str, Optional[int]] - ) -> Optional[dict[str, Any]]: + self, token: str, updatevalues: dict[str, int | None] + ) -> dict[str, Any] | None: """Update a registration token. Used by the admin API. Args: @@ -1909,7 +1907,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): def _update_registration_token_txn( txn: LoggingTransaction, - ) -> Optional[dict[str, Any]]: + ) -> dict[str, Any] | None: try: self.db_pool.simple_update_one_txn( txn, @@ -1996,14 +1994,12 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): desc="mark_access_token_as_used", ) - async def lookup_refresh_token( - self, token: str - ) -> Optional[RefreshTokenLookupResult]: + async def lookup_refresh_token(self, token: str) -> RefreshTokenLookupResult | None: """Lookup a refresh token with hints about its validity.""" def _lookup_refresh_token_txn( txn: LoggingTransaction, - ) -> Optional[RefreshTokenLookupResult]: + ) -> RefreshTokenLookupResult | None: txn.execute( """ SELECT @@ -2154,8 +2150,8 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): user_id: str, token: str, expiry_ts: int, - auth_provider_id: Optional[str], - auth_provider_session_id: Optional[str], + auth_provider_id: str | None, + auth_provider_session_id: str | None, ) -> None: """Adds a short-term login token for the given user. @@ -2455,9 +2451,9 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): async def user_delete_access_tokens( self, user_id: str, - except_token_id: Optional[int] = None, - device_id: Optional[str] = None, - ) -> list[tuple[str, int, Optional[str]]]: + except_token_id: int | None = None, + device_id: str | None = None, + ) -> list[tuple[str, int, str | None]]: """ Invalidate access and refresh tokens belonging to a user @@ -2471,14 +2467,14 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): A tuple of (token, token id, device id) for each of the deleted tokens """ - def f(txn: LoggingTransaction) -> list[tuple[str, int, Optional[str]]]: + def f(txn: LoggingTransaction) -> list[tuple[str, int, str | None]]: keyvalues = {"user_id": user_id} if device_id is not None: keyvalues["device_id"] = device_id items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) - values: list[Union[str, int]] = [v for _, v in items] + values: list[str | int] = [v for _, v in items] # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where # clause and values before we handle that. This seems to be only used in the "set password" handler. @@ -2517,7 +2513,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): self, user_id: str, device_ids: StrCollection, - ) -> list[tuple[str, int, Optional[str]]]: + ) -> list[tuple[str, int, str | None]]: """ Invalidate access and refresh tokens belonging to a user @@ -2530,7 +2526,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): def user_delete_access_tokens_for_devices_txn( txn: LoggingTransaction, batch_device_ids: StrCollection - ) -> list[tuple[str, int, Optional[str]]]: + ) -> list[tuple[str, int, str | None]]: self.db_pool.simple_delete_many_txn( txn, table="refresh_tokens", @@ -2583,7 +2579,7 @@ class RegistrationWorkerStore(StatsStore, CacheInvalidationWorkerStore): await self.db_pool.runInteraction("delete_access_token", f) async def user_set_password_hash( - self, user_id: str, password_hash: Optional[str] + self, user_id: str, password_hash: str | None ) -> None: """ NB. This does *not* evict any cache because the one use for this @@ -2750,10 +2746,10 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self, user_id: str, token: str, - device_id: Optional[str], - valid_until_ms: Optional[int], - puppets_user_id: Optional[str] = None, - refresh_token_id: Optional[int] = None, + device_id: str | None, + valid_until_ms: int | None, + puppets_user_id: str | None = None, + refresh_token_id: int | None = None, ) -> int: """Adds an access token for the given user. @@ -2795,9 +2791,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self, user_id: str, token: str, - device_id: Optional[str], - expiry_ts: Optional[int], - ultimate_session_expiry_ts: Optional[int], + device_id: str | None, + expiry_ts: int | None, + ultimate_session_expiry_ts: int | None, ) -> int: """Adds a refresh token for the given user. @@ -2889,7 +2885,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): async def validate_threepid_session( self, session_id: str, client_secret: str, token: str, current_ts: int - ) -> Optional[str]: + ) -> str | None: """Attempt to validate a threepid session using a token Args: @@ -2909,7 +2905,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ # Insert everything into a transaction in order to run atomically - def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]: + def validate_threepid_session_txn(txn: LoggingTransaction) -> str | None: row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_session", @@ -2984,7 +2980,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): session_id: str, client_secret: str, send_attempt: int, - next_link: Optional[str], + next_link: str | None, token: str, token_expires: int, ) -> None: diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py index a603258644..c73c3d761d 100644 --- a/synapse/storage/databases/main/rejections.py +++ b/synapse/storage/databases/main/rejections.py @@ -20,7 +20,6 @@ # import logging -from typing import Optional from synapse.storage._base import SQLBaseStore @@ -28,7 +27,7 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - async def get_rejection_reason(self, event_id: str) -> Optional[str]: + async def get_rejection_reason(self, event_id: str) -> str | None: return await self.db_pool.simple_select_one_onecol( table="rejections", retcol="reason", diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 529102c245..9d9c37e2a4 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -24,9 +24,7 @@ from typing import ( Collection, Iterable, Mapping, - Optional, Sequence, - Union, cast, ) @@ -167,14 +165,14 @@ class RelationsWorkerStore(SQLBaseStore): room_id: str, event_id: str, event: EventBase, - relation_type: Optional[str] = None, - event_type: Optional[str] = None, + relation_type: str | None = None, + event_type: str | None = None, limit: int = 5, direction: Direction = Direction.BACKWARDS, - from_token: Optional[StreamToken] = None, - to_token: Optional[StreamToken] = None, + from_token: StreamToken | None = None, + to_token: StreamToken | None = None, recurse: bool = False, - ) -> tuple[Sequence[_RelatedEvent], Optional[StreamToken]]: + ) -> tuple[Sequence[_RelatedEvent], StreamToken | None]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -204,7 +202,7 @@ class RelationsWorkerStore(SQLBaseStore): assert limit >= 0 where_clause = ["room_id = ?"] - where_args: list[Union[str, int]] = [room_id] + where_args: list[str | int] = [room_id] is_redacted = event.internal_metadata.is_redacted() if relation_type is not None: @@ -276,7 +274,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_recent_references_for_event_txn( txn: LoggingTransaction, - ) -> tuple[list[_RelatedEvent], Optional[StreamToken]]: + ) -> tuple[list[_RelatedEvent], StreamToken | None]: txn.execute(sql, [event.event_id] + where_args + [limit + 1]) events = [] @@ -463,7 +461,7 @@ class RelationsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") async def get_references_for_events( self, event_ids: Collection[str] - ) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]: + ) -> Mapping[str, Sequence[_RelatedEvent] | None]: """Get a list of references to the given events. Args: @@ -511,14 +509,14 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() # type: ignore[synapse-@cached-mutable] - def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + def get_applicable_edit(self, event_id: str) -> EventBase | None: raise NotImplementedError() # TODO: This returns a mutable object, which is generally bad. @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") # type: ignore[synapse-@cached-mutable] async def get_applicable_edits( self, event_ids: Collection[str] - ) -> Mapping[str, Optional[EventBase]]: + ) -> Mapping[str, EventBase | None]: """Get the most recent edit (if any) that has happened for the given events. @@ -598,14 +596,14 @@ class RelationsWorkerStore(SQLBaseStore): } @cached() # type: ignore[synapse-@cached-mutable] - def get_thread_summary(self, event_id: str) -> Optional[tuple[int, EventBase]]: + def get_thread_summary(self, event_id: str) -> tuple[int, EventBase] | None: raise NotImplementedError() # TODO: This returns a mutable object, which is generally bad. @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") # type: ignore[synapse-@cached-mutable] async def get_thread_summaries( self, event_ids: Collection[str] - ) -> Mapping[str, Optional[tuple[int, EventBase]]]: + ) -> Mapping[str, tuple[int, EventBase] | None]: """Get the number of threaded replies and the latest reply (if any) for the given events. Args: @@ -826,8 +824,8 @@ class RelationsWorkerStore(SQLBaseStore): async def events_have_relations( self, parent_ids: list[str], - relation_senders: Optional[list[str]], - relation_types: Optional[list[str]], + relation_senders: list[str] | None, + relation_types: list[str] | None, ) -> list[str]: """Check which events have a relationship from the given senders of the given types. @@ -930,8 +928,8 @@ class RelationsWorkerStore(SQLBaseStore): self, room_id: str, limit: int = 5, - from_token: Optional[ThreadsNextBatch] = None, - ) -> tuple[Sequence[str], Optional[ThreadsNextBatch]]: + from_token: ThreadsNextBatch | None = None, + ) -> tuple[Sequence[str], ThreadsNextBatch | None]: """Get a list of thread IDs, ordered by topological ordering of their latest reply. @@ -971,7 +969,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_threads_txn( txn: LoggingTransaction, - ) -> tuple[list[str], Optional[ThreadsNextBatch]]: + ) -> tuple[list[str], ThreadsNextBatch | None]: txn.execute(sql, (room_id, *pagination_args, limit + 1)) rows = cast(list[tuple[str, int, int]], txn.fetchall()) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7a294de558..633df07736 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -28,8 +28,6 @@ from typing import ( Any, Collection, Mapping, - Optional, - Union, cast, ) @@ -82,24 +80,24 @@ class RatelimitOverride: @attr.s(slots=True, frozen=True, auto_attribs=True) class LargestRoomStats: room_id: str - name: Optional[str] - canonical_alias: Optional[str] + name: str | None + canonical_alias: str | None joined_members: int - join_rules: Optional[str] - guest_access: Optional[str] - history_visibility: Optional[str] + join_rules: str | None + guest_access: str | None + history_visibility: str | None state_events: int - avatar: Optional[str] - topic: Optional[str] - room_type: Optional[str] + avatar: str | None + topic: str | None + room_type: str | None @attr.s(slots=True, frozen=True, auto_attribs=True) class RoomStats(LargestRoomStats): joined_local_members: int - version: Optional[str] - creator: Optional[str] - encryption: Optional[str] + version: str | None + creator: str | None + encryption: str | None federatable: bool public: bool @@ -134,7 +132,7 @@ class RoomSortOrder(Enum): @attr.s(slots=True, frozen=True, auto_attribs=True) class PartialStateResyncInfo: - joined_via: Optional[str] + joined_via: str | None servers_in_room: set[str] = attr.ib(factory=set) @@ -205,7 +203,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def get_room(self, room_id: str) -> Optional[tuple[bool, bool]]: + async def get_room(self, room_id: str) -> tuple[bool, bool] | None: """Retrieve a room. Args: @@ -218,7 +216,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): or None if the room is unknown. """ row = cast( - Optional[tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]], + tuple[int | bool | None, int | bool | None] | None, await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, @@ -231,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return row return bool(row[0]), bool(row[1]) - async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: + async def get_room_with_stats(self, room_id: str) -> RoomStats | None: """Retrieve room with statistics. Args: @@ -242,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def get_room_with_stats_txn( txn: LoggingTransaction, room_id: str - ) -> Optional[RoomStats]: + ) -> RoomStats | None: sql = """ SELECT room_id, state.name, state.canonical_alias, curr.joined_members, curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, @@ -292,8 +290,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) def _construct_room_type_where_clause( - self, room_types: Union[list[Union[str, None]], None] - ) -> tuple[Union[str, None], list]: + self, room_types: list[str | None] | None + ) -> tuple[str | None, list]: if not room_types: return None, [] @@ -320,9 +318,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): async def count_public_rooms( self, - network_tuple: Optional[ThirdPartyInstanceID], + network_tuple: ThirdPartyInstanceID | None, ignore_non_federatable: bool, - search_filter: Optional[dict], + search_filter: dict | None, ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. @@ -402,10 +400,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): async def get_largest_public_rooms( self, - network_tuple: Optional[ThirdPartyInstanceID], - search_filter: Optional[dict], - limit: Optional[int], - bounds: Optional[tuple[int, str]], + network_tuple: ThirdPartyInstanceID | None, + search_filter: dict | None, + limit: int | None, + bounds: tuple[int, str] | None, forwards: bool, ignore_non_federatable: bool = False, ) -> list[LargestRoomStats]: @@ -429,7 +427,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): """ where_clauses = [] - query_args: list[Union[str, int]] = [] + query_args: list[str | int] = [] if network_tuple: if network_tuple.appservice_id: @@ -575,7 +573,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) @cached(max_entries=10000) - async def is_room_blocked(self, room_id: str) -> Optional[bool]: + async def is_room_blocked(self, room_id: str) -> bool | None: return await self.db_pool.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, @@ -584,7 +582,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): desc="is_room_blocked", ) - async def room_is_blocked_by(self, room_id: str) -> Optional[str]: + async def room_is_blocked_by(self, room_id: str) -> str | None: """ Function to retrieve user who has blocked the room. user_id is non-nullable @@ -604,9 +602,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): limit: int, order_by: str, reverse_order: bool, - search_term: Optional[str], - public_rooms: Optional[bool], - empty_rooms: Optional[bool], + search_term: str | None, + public_rooms: bool | None, + empty_rooms: bool | None, ) -> tuple[list[dict[str, Any]], int]: """Function to retrieve a paginated list of rooms as json. @@ -800,7 +798,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) @cached(max_entries=10000) - async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]: + async def get_ratelimit_for_user(self, user_id: str) -> RatelimitOverride | None: """Check if there are any overrides for ratelimiting for the given user Args: @@ -905,7 +903,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def get_retention_policy_for_room_txn( txn: LoggingTransaction, - ) -> Optional[tuple[Optional[int], Optional[int]]]: + ) -> tuple[int | None, int | None] | None: txn.execute( """ SELECT min_lifetime, max_lifetime FROM room_retention @@ -915,7 +913,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): (room_id,), ) - return cast(Optional[tuple[Optional[int], Optional[int]]], txn.fetchone()) + return cast(tuple[int | None, int | None] | None, txn.fetchone()) ret = await self.db_pool.runInteraction( "get_retention_policy_for_room", @@ -1058,7 +1056,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): self, server_name: str, media_id: str, - quarantined_by: Optional[str], + quarantined_by: str | None, ) -> int: """quarantines or unquarantines a single local or remote media id @@ -1135,7 +1133,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): txn: LoggingTransaction, hashes: set[str], media_ids: set[str], - quarantined_by: Optional[str], + quarantined_by: str | None, ) -> int: """Quarantine and unquarantine local media items. @@ -1190,7 +1188,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): txn: LoggingTransaction, hashes: set[str], media: set[tuple[str, str]], - quarantined_by: Optional[str], + quarantined_by: str | None, ) -> int: """Quarantine and unquarantine remote items @@ -1238,7 +1236,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): txn: LoggingTransaction, local_mxcs: list[str], remote_mxcs: list[tuple[str, str]], - quarantined_by: Optional[str], + quarantined_by: str | None, ) -> int: """Quarantine and unquarantine local and remote media items @@ -1341,7 +1339,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) async def get_rooms_for_retention_period_in_range( - self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False + self, min_ms: int | None, max_ms: int | None, include_null: bool = False ) -> dict[str, RetentionPolicy]: """Retrieves all of the rooms within the given retention range. @@ -1421,7 +1419,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): async def get_partial_state_servers_at_join( self, room_id: str - ) -> Optional[AbstractSet[str]]: + ) -> AbstractSet[str] | None: """Gets the set of servers in a partial state room at the time we joined it. Returns: @@ -1682,7 +1680,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_un_partial_stated_rooms_from_stream_txn, ) - async def get_event_report(self, report_id: int) -> Optional[dict[str, Any]]: + async def get_event_report(self, report_id: int) -> dict[str, Any] | None: """Retrieve an event report Args: @@ -1694,7 +1692,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): def _get_event_report_txn( txn: LoggingTransaction, report_id: int - ) -> Optional[dict[str, Any]]: + ) -> dict[str, Any] | None: sql = """ SELECT er.id, @@ -1748,9 +1746,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): start: int, limit: int, direction: Direction = Direction.BACKWARDS, - user_id: Optional[str] = None, - room_id: Optional[str] = None, - event_sender_user_id: Optional[str] = None, + user_id: str | None = None, + room_id: str | None = None, + event_sender_user_id: str | None = None, ) -> tuple[list[dict[str, Any]], int]: """Retrieve a paginated list of event reports @@ -2602,7 +2600,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, event_id: str, user_id: str, - reason: Optional[str], + reason: str | None, content: JsonDict, received_ts: int, ) -> int: @@ -2696,7 +2694,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): ) return next_id - async def clear_partial_state_room(self, room_id: str) -> Optional[int]: + async def clear_partial_state_room(self, room_id: str) -> int | None: """Clears the partial state flag for a room. Args: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 1e22ab4e6d..4fb7779d38 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -26,9 +26,7 @@ from typing import ( Collection, Iterable, Mapping, - Optional, Sequence, - Union, cast, ) @@ -446,7 +444,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): async def get_invite_for_local_user_in_room( self, user_id: str, room_id: str - ) -> Optional[RoomsForUser]: + ) -> RoomsForUser | None: """Gets the invite for the given *local* user and room. Args: @@ -655,7 +653,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): async def get_local_current_membership_for_user_in_room( self, user_id: str, room_id: str - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """Retrieve the current local membership state and event ID for a user in a room. Args: @@ -672,7 +670,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON) results = cast( - Optional[tuple[str, str]], + tuple[str, str] | None, await self.db_pool.simple_select_one( "local_current_membership", {"room_id": room_id, "user_id": user_id}, @@ -833,7 +831,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def _do_users_share_a_room( self, user_id: str, other_user_ids: Collection[str] - ) -> Mapping[str, Optional[bool]]: + ) -> Mapping[str, bool | None]: """Return mapping from user ID to whether they share a room with the given user. @@ -896,7 +894,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def _do_users_share_a_room_joined_or_invited( self, user_id: str, other_user_ids: Collection[str] - ) -> Mapping[str, Optional[bool]]: + ) -> Mapping[str, bool | None]: """Return mapping from user ID to whether they share a room with the given user via being either joined or invited. @@ -974,7 +972,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): overlapping joined rooms for. cache_context """ - shared_room_ids: Optional[frozenset[str]] = None + shared_room_ids: frozenset[str] | None = None for user_id in user_ids: room_ids = await self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate @@ -1045,7 +1043,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) def _get_user_id_from_membership_event_id( self, event_id: str - ) -> Optional[tuple[str, ProfileInfo]]: + ) -> tuple[str, ProfileInfo] | None: raise NotImplementedError() @cachedList( @@ -1054,7 +1052,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def _get_user_ids_from_membership_event_ids( self, event_ids: Iterable[str] - ) -> Mapping[str, Optional[str]]: + ) -> Mapping[str, str | None]: """For given set of member event_ids check if they point to a join event. @@ -1229,7 +1227,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): async def _get_approximate_current_memberships_in_room( self, room_id: str - ) -> Mapping[str, Optional[str]]: + ) -> Mapping[str, str | None]: """Build a map from event id to membership, for all events in the current state. The event ids of non-memberships events (e.g. `m.room.power_levels`) are present @@ -1240,7 +1238,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ rows = cast( - list[tuple[str, Optional[str]]], + list[tuple[str, str | None]], await self.db_pool.simple_select_list( "current_state_events", keyvalues={"room_id": room_id}, @@ -1387,7 +1385,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): @cached(max_entries=5000) async def _get_membership_from_event_id( self, member_event_id: str - ) -> Optional[EventIdMembership]: + ) -> EventIdMembership | None: raise NotImplementedError() @cachedList( @@ -1395,7 +1393,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] - ) -> Mapping[str, Optional[EventIdMembership]]: + ) -> Mapping[str, EventIdMembership | None]: """Get user_id and membership of a set of event IDs. Returns: @@ -1680,12 +1678,12 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): async def get_sliding_sync_room_for_user( self, user_id: str, room_id: str - ) -> Optional[RoomsForUserSlidingSync]: + ) -> RoomsForUserSlidingSync | None: """Get the sliding sync room entry for the given user and room.""" def get_sliding_sync_room_for_user_txn( txn: LoggingTransaction, - ) -> Optional[RoomsForUserSlidingSync]: + ) -> RoomsForUserSlidingSync | None: sql = """ SELECT m.room_id, m.sender, m.membership, m.membership_event_id, r.room_version, @@ -2106,7 +2104,7 @@ class _JoinedHostsCache: # if the instance is newly created or if the state is not based on a state # group. (An object is used as a sentinel value to ensure that it never is # equal to anything else). - state_group: Union[object, int] = attr.Factory(object) + state_group: object | int = attr.Factory(object) def __len__(self) -> int: return sum(len(v) for v in self.hosts_to_joined_users.values()) diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 63489f5c27..d6eace5efa 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -28,8 +28,6 @@ from typing import ( Any, Collection, Iterable, - Optional, - Union, cast, ) @@ -60,7 +58,7 @@ class SearchEntry: value: str event_id: str room_id: str - stream_ordering: Optional[int] + stream_ordering: int | None origin_server_ts: int @@ -516,7 +514,7 @@ class SearchStore(SearchBackgroundUpdateStore): # List of tuples of (rank, room_id, event_id). results = cast( - list[tuple[Union[int, float], str, str]], + list[tuple[int | float, str, str]], await self.db_pool.execute("search_msgs", sql, *args), ) @@ -562,7 +560,7 @@ class SearchStore(SearchBackgroundUpdateStore): search_term: str, keys: Iterable[str], limit: int, - pagination_token: Optional[str] = None, + pagination_token: str | None = None, ) -> JsonDict: """Performs a full text search over events with given keys. @@ -683,7 +681,7 @@ class SearchStore(SearchBackgroundUpdateStore): # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering). results = cast( - list[tuple[Union[int, float], str, str, int, int]], + list[tuple[int | float, str, str, int, int]], await self.db_pool.execute("search_rooms", sql, *args), ) @@ -817,7 +815,7 @@ class SearchToken(enum.Enum): And = enum.auto() -Token = Union[str, Phrase, SearchToken] +Token = str | Phrase | SearchToken TokenList = list[Token] diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py index 62463c0259..2b67e75ac4 100644 --- a/synapse/storage/databases/main/sliding_sync.py +++ b/synapse/storage/databases/main/sliding_sync.py @@ -14,7 +14,7 @@ import logging -from typing import TYPE_CHECKING, Mapping, Optional, cast +from typing import TYPE_CHECKING, Mapping, cast import attr @@ -79,7 +79,7 @@ class SlidingSyncStore(SQLBaseStore): async def get_latest_bump_stamp_for_room( self, room_id: str, - ) -> Optional[int]: + ) -> int | None: """ Get the `bump_stamp` for the room. @@ -99,7 +99,7 @@ class SlidingSyncStore(SQLBaseStore): """ return cast( - Optional[int], + int | None, await self.db_pool.simple_select_one_onecol( table="sliding_sync_joined_rooms", keyvalues={"room_id": room_id}, @@ -121,7 +121,7 @@ class SlidingSyncStore(SQLBaseStore): user_id: str, device_id: str, conn_id: str, - previous_connection_position: Optional[int], + previous_connection_position: int | None, per_connection_state: "MutablePerConnectionState", ) -> int: """Persist updates to the per-connection state for a sliding sync @@ -154,7 +154,7 @@ class SlidingSyncStore(SQLBaseStore): user_id: str, device_id: str, conn_id: str, - previous_connection_position: Optional[int], + previous_connection_position: int | None, per_connection_state: "PerConnectionStateDB", ) -> int: # First we fetch (or create) the connection key associated with the diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index c2c1b62d7e..a0aea4975c 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -28,9 +28,7 @@ from typing import ( Iterable, Mapping, MutableMapping, - Optional, TypeVar, - Union, cast, overload, ) @@ -86,8 +84,8 @@ class EventMetadata: room_id: str event_type: str - state_key: Optional[str] - rejection_reason: Optional[str] + state_key: str | None + rejection_reason: str | None def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: @@ -243,7 +241,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return result_map - async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: + async def get_room_predecessor(self, room_id: str) -> JsonMapping | None: """Get the predecessor of an upgraded room if it exists. Otherwise return None. @@ -303,7 +301,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=10000) - async def get_room_type(self, room_id: str) -> Union[Optional[str], Sentinel]: + async def get_room_type(self, room_id: str) -> str | None | Sentinel: """Fetch room type for given room. Since this function is cached, any missing values would be cached as @@ -325,7 +323,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @cachedList(cached_method_name="get_room_type", list_name="room_ids") async def bulk_get_room_type( self, room_ids: set[str] - ) -> Mapping[str, Union[Optional[str], Sentinel]]: + ) -> Mapping[str, str | None | Sentinel]: """ Bulk fetch room types for the given rooms (via current state). @@ -342,7 +340,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def txn( txn: LoggingTransaction, - ) -> MutableMapping[str, Union[Optional[str], Sentinel]]: + ) -> MutableMapping[str, str | None | Sentinel]: clause, args = make_in_list_sql_clause( txn.database_engine, "room_id", room_ids ) @@ -398,13 +396,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results @cached(max_entries=10000) - async def get_room_encryption(self, room_id: str) -> Optional[str]: + async def get_room_encryption(self, room_id: str) -> str | None: raise NotImplementedError() @cachedList(cached_method_name="get_room_encryption", list_name="room_ids") async def bulk_get_room_encryption( self, room_ids: set[str] - ) -> Mapping[str, Union[Optional[str], Sentinel]]: + ) -> Mapping[str, str | None | Sentinel]: """ Bulk fetch room encryption for the given rooms (via current state). @@ -422,7 +420,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def txn( txn: LoggingTransaction, - ) -> MutableMapping[str, Union[Optional[str], Sentinel]]: + ) -> MutableMapping[str, str | None | Sentinel]: clause, args = make_in_list_sql_clause( txn.database_engine, "room_id", room_ids ) @@ -551,7 +549,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # FIXME: how should this be cached? @cancellable async def get_partial_filtered_current_state_ids( - self, room_id: str, state_filter: Optional[StateFilter] = None + self, room_id: str, state_filter: StateFilter | None = None ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result @@ -604,7 +602,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) @cached(max_entries=50000) - async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: + async def _get_state_group_for_event(self, event_id: str) -> int | None: return await self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, @@ -986,15 +984,13 @@ class StateMapWrapper(dict[StateKey, str]): return super().__getitem__(key) @overload # type: ignore[override] - def get(self, key: StateKey, default: None = None, /) -> Optional[str]: ... + def get(self, key: StateKey, default: None = None, /) -> str | None: ... @overload def get(self, key: StateKey, default: str, /) -> str: ... @overload - def get(self, key: StateKey, default: _T, /) -> Union[str, _T]: ... + def get(self, key: StateKey, default: _T, /) -> str | _T: ... - def get( - self, key: StateKey, default: Union[str, _T, None] = None - ) -> Union[str, _T, None]: + def get(self, key: StateKey, default: str | _T | None = None) -> str | _T | None: if key not in self.state_filter: raise Exception("State map was filtered and doesn't include: %s", key) return super().get(key, default) diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 3df5c8b6f4..cd8f286d08 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import attr @@ -50,10 +50,10 @@ class StateDelta: event_type: str state_key: str - event_id: Optional[str] + event_id: str | None """new event_id for this state key. None if the state has been deleted.""" - prev_event_id: Optional[str] + prev_event_id: str | None """previous event_id for this state key. None if it's new state.""" @@ -191,8 +191,8 @@ class StateDeltasStore(SQLBaseStore): txn: LoggingTransaction, room_id: str, *, - from_token: Optional[RoomStreamToken], - to_token: Optional[RoomStreamToken], + from_token: RoomStreamToken | None, + to_token: RoomStreamToken | None, ) -> list[StateDelta]: """ Get the state deltas between two tokens. @@ -237,8 +237,8 @@ class StateDeltasStore(SQLBaseStore): self, room_id: str, *, - from_token: Optional[RoomStreamToken], - to_token: Optional[RoomStreamToken], + from_token: RoomStreamToken | None, + to_token: RoomStreamToken | None, ) -> list[StateDelta]: """ Get the state deltas between two tokens. diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 19e525a3cd..6568e2aa08 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -27,8 +27,6 @@ from typing import ( Any, Counter, Iterable, - Optional, - Union, cast, ) @@ -296,7 +294,7 @@ class StatsStore(StateDeltasStore): @cached() async def get_earliest_token_for_stats( self, stats_type: str, id: str - ) -> Optional[int]: + ) -> int | None: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -362,7 +360,7 @@ class StatsStore(StateDeltasStore): stats_id: str, fields: dict[str, int], complete_with_stream_id: int, - absolute_field_overrides: Optional[dict[str, int]] = None, + absolute_field_overrides: dict[str, int] | None = None, ) -> None: """ Updates the statistics for a subject, with a delta (difference/relative @@ -400,7 +398,7 @@ class StatsStore(StateDeltasStore): stats_id: str, fields: dict[str, int], complete_with_stream_id: int, - absolute_field_overrides: Optional[dict[str, int]] = None, + absolute_field_overrides: dict[str, int] | None = None, ) -> None: if absolute_field_overrides is None: absolute_field_overrides = {} @@ -585,7 +583,7 @@ class StatsStore(StateDeltasStore): ) return - room_state: dict[str, Union[None, bool, str]] = { + room_state: dict[str, None | bool | str] = { "join_rules": None, "history_visibility": None, "encryption": None, @@ -680,12 +678,12 @@ class StatsStore(StateDeltasStore): self, start: int, limit: int, - from_ts: Optional[int] = None, - until_ts: Optional[int] = None, - order_by: Optional[str] = UserSortOrder.USER_ID.value, + from_ts: int | None = None, + until_ts: int | None = None, + order_by: str | None = UserSortOrder.USER_ID.value, direction: Direction = Direction.FORWARDS, - search_term: Optional[str] = None, - ) -> tuple[list[tuple[str, Optional[str], int, int]], int]: + search_term: str | None = None, + ) -> tuple[list[tuple[str, str | None, int, int]], int]: """Function to retrieve a paginated list of users and their uploaded local media (size and number). This will return a json list of users and the total number of users matching the filter criteria. @@ -710,7 +708,7 @@ class StatsStore(StateDeltasStore): def get_users_media_usage_paginate_txn( txn: LoggingTransaction, - ) -> tuple[list[tuple[str, Optional[str], int, int]], int]: + ) -> tuple[list[tuple[str, str | None, int, int]], int]: filters = [] args: list = [] @@ -782,7 +780,7 @@ class StatsStore(StateDeltasStore): args += [limit, start] txn.execute(sql, args) - users = cast(list[tuple[str, Optional[str], int, int]], txn.fetchall()) + users = cast(list[tuple[str, str | None, int, int]], txn.fetchall()) return users, count diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index e8ea1e5480..8644ff412e 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -50,7 +50,6 @@ from typing import ( Iterable, Literal, Mapping, - Optional, Protocol, cast, overload, @@ -102,7 +101,7 @@ class PaginateFunction(Protocol): *, room_id: str, from_key: RoomStreamToken, - to_key: Optional[RoomStreamToken] = None, + to_key: RoomStreamToken | None = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, ) -> tuple[list[EventBase], RoomStreamToken, bool]: ... @@ -112,7 +111,7 @@ class PaginateFunction(Protocol): @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: event_id: str - topological_ordering: Optional[int] + topological_ordering: int | None stream_ordering: int @@ -139,22 +138,22 @@ class CurrentStateDeltaMembership: room_id: str # Event - event_id: Optional[str] + event_id: str | None event_pos: PersistedEventPosition membership: str - sender: Optional[str] + sender: str | None # Prev event - prev_event_id: Optional[str] - prev_event_pos: Optional[PersistedEventPosition] - prev_membership: Optional[str] - prev_sender: Optional[str] + prev_event_id: str | None + prev_event_pos: PersistedEventPosition | None + prev_membership: str | None + prev_sender: str | None def generate_pagination_where_clause( direction: Direction, column_names: tuple[str, str], - from_token: Optional[tuple[Optional[int], int]], - to_token: Optional[tuple[Optional[int], int]], + from_token: tuple[int | None, int] | None, + to_token: tuple[int | None, int] | None, engine: BaseDatabaseEngine, ) -> str: """Creates an SQL expression to bound the columns by the pagination @@ -218,11 +217,9 @@ def generate_pagination_where_clause( def generate_pagination_bounds( direction: Direction, - from_token: Optional[RoomStreamToken], - to_token: Optional[RoomStreamToken], -) -> tuple[ - str, Optional[tuple[Optional[int], int]], Optional[tuple[Optional[int], int]] -]: + from_token: RoomStreamToken | None, + to_token: RoomStreamToken | None, +) -> tuple[str, tuple[int | None, int] | None, tuple[int | None, int] | None]: """ Generate a start and end point for this page of events. @@ -257,7 +254,7 @@ def generate_pagination_bounds( # by fetching all events between the min stream token and the maximum # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and # then filtering the results. - from_bound: Optional[tuple[Optional[int], int]] = None + from_bound: tuple[int | None, int] | None = None if from_token: if from_token.topological is not None: from_bound = from_token.as_historical_tuple() @@ -272,7 +269,7 @@ def generate_pagination_bounds( from_token.stream, ) - to_bound: Optional[tuple[Optional[int], int]] = None + to_bound: tuple[int | None, int] | None = None if to_token: if to_token.topological is not None: to_bound = to_token.as_historical_tuple() @@ -291,7 +288,7 @@ def generate_pagination_bounds( def generate_next_token( - direction: Direction, last_topo_ordering: Optional[int], last_stream_ordering: int + direction: Direction, last_topo_ordering: int | None, last_stream_ordering: int ) -> RoomStreamToken: """ Generate the next room stream token based on the currently returned data. @@ -317,7 +314,7 @@ def generate_next_token( def _make_generic_sql_bound( bound: str, column_names: tuple[str, str], - values: tuple[Optional[int], int], + values: tuple[int | None, int], engine: BaseDatabaseEngine, ) -> str: """Create an SQL expression that bounds the given column names by the @@ -381,9 +378,9 @@ def _make_generic_sql_bound( def _filter_results( - lower_token: Optional[RoomStreamToken], - upper_token: Optional[RoomStreamToken], - instance_name: Optional[str], + lower_token: RoomStreamToken | None, + upper_token: RoomStreamToken | None, + instance_name: str | None, topological_ordering: int, stream_ordering: int, ) -> bool: @@ -436,9 +433,9 @@ def _filter_results( def _filter_results_by_stream( - lower_token: Optional[RoomStreamToken], - upper_token: Optional[RoomStreamToken], - instance_name: Optional[str], + lower_token: RoomStreamToken | None, + upper_token: RoomStreamToken | None, + instance_name: str | None, stream_ordering: int, ) -> bool: """ @@ -480,7 +477,7 @@ def _filter_results_by_stream( return True -def filter_to_clause(event_filter: Optional[Filter]) -> tuple[str, list[str]]: +def filter_to_clause(event_filter: Filter | None) -> tuple[str, list[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create # "room_id == X AND room_id != X", which postgres doesn't optimise. @@ -662,7 +659,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): *, room_ids: Collection[str], from_key: RoomStreamToken, - to_key: Optional[RoomStreamToken] = None, + to_key: RoomStreamToken | None = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, ) -> dict[str, tuple[list[EventBase], RoomStreamToken, bool]]: @@ -784,7 +781,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): *, room_id: str, from_key: RoomStreamToken, - to_key: Optional[RoomStreamToken] = None, + to_key: RoomStreamToken | None = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, ) -> tuple[list[EventBase], RoomStreamToken, bool]: @@ -936,7 +933,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken, - excluded_room_ids: Optional[list[str]] = None, + excluded_room_ids: list[str] | None = None, ) -> list[CurrentStateDeltaMembership]: """ Fetch membership events (and the previous event that was replaced by that one) @@ -1131,7 +1128,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken, - excluded_room_ids: Optional[AbstractSet[str]] = None, + excluded_room_ids: AbstractSet[str] | None = None, ) -> dict[str, RoomsForUserStateReset]: """ Fetch membership events that result in a meaningful membership change for a @@ -1328,7 +1325,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken, - excluded_rooms: Optional[list[str]] = None, + excluded_rooms: list[str] | None = None, ) -> list[EventBase]: """Fetch membership events for a given user. @@ -1455,7 +1452,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): async def get_room_event_before_stream_ordering( self, room_id: str, stream_ordering: int - ) -> Optional[tuple[int, int, str]]: + ) -> tuple[int, int, str] | None: """Gets details of the first event in a room at or before a stream ordering Args: @@ -1466,7 +1463,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): A tuple of (stream ordering, topological ordering, event_id) """ - def _f(txn: LoggingTransaction) -> Optional[tuple[int, int, str]]: + def _f(txn: LoggingTransaction) -> tuple[int, int, str] | None: sql = """ SELECT stream_ordering, topological_ordering, event_id FROM events @@ -1479,7 +1476,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LIMIT 1 """ txn.execute(sql, (room_id, stream_ordering)) - return cast(Optional[tuple[int, int, str]], txn.fetchone()) + return cast(tuple[int, int, str] | None, txn.fetchone()) return await self.db_pool.runInteraction( "get_room_event_before_stream_ordering", _f @@ -1489,7 +1486,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self, room_id: str, end_token: RoomStreamToken, - ) -> Optional[str]: + ) -> str | None: """Returns the ID of the last event in a room at or before a stream ordering Args: @@ -1514,8 +1511,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): async def get_last_event_pos_in_room( self, room_id: str, - event_types: Optional[StrCollection] = None, - ) -> Optional[tuple[str, PersistedEventPosition]]: + event_types: StrCollection | None = None, + ) -> tuple[str, PersistedEventPosition] | None: """ Returns the ID and event position of the last event in a room. @@ -1532,7 +1529,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_last_event_pos_in_room_txn( txn: LoggingTransaction, - ) -> Optional[tuple[str, PersistedEventPosition]]: + ) -> tuple[str, PersistedEventPosition] | None: event_type_clause = "" event_type_args: list[str] = [] if event_types is not None and len(event_types) > 0: @@ -1558,7 +1555,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): [room_id] + event_type_args, ) - row = cast(Optional[tuple[str, int, str]], txn.fetchone()) + row = cast(tuple[str, int, str] | None, txn.fetchone()) if row is not None: event_id, stream_ordering, instance_name = row @@ -1580,8 +1577,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self, room_id: str, end_token: RoomStreamToken, - event_types: Optional[StrCollection] = None, - ) -> Optional[tuple[str, PersistedEventPosition]]: + event_types: StrCollection | None = None, + ) -> tuple[str, PersistedEventPosition] | None: """ Returns the ID and event position of the last event in a room at or before a stream ordering. @@ -1598,7 +1595,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def get_last_event_pos_in_room_before_stream_ordering_txn( txn: LoggingTransaction, - ) -> Optional[tuple[str, PersistedEventPosition]]: + ) -> tuple[str, PersistedEventPosition] | None: # We're looking for the closest event at or before the token. We need to # handle the fact that the stream token can be a vector clock (with an # `instance_map`) and events can be persisted on different instances @@ -1735,7 +1732,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): @cachedList(cached_method_name="_get_max_event_pos", list_name="room_ids") async def _bulk_get_max_event_pos( self, room_ids: StrCollection - ) -> Mapping[str, Optional[int]]: + ) -> Mapping[str, int | None]: """Fetch the max position of a persisted event in the room.""" # We need to be careful not to return positions ahead of the current @@ -1860,14 +1857,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn: LoggingTransaction, event_id: str, allow_none: bool = False, - ) -> Optional[int]: ... + ) -> int | None: ... def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, allow_none: bool = False, - ) -> Optional[int]: + ) -> int | None: # Type ignore: we pass keyvalues a Dict[str, str]; the function wants # Dict[str, Any]. I think mypy is unhappy because Dict is invariant? return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload] @@ -1970,7 +1967,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_id: str, before_limit: int, after_limit: int, - event_filter: Optional[Filter] = None, + event_filter: Filter | None = None, ) -> _EventsAround: """Retrieve events and pagination tokens around a given event in a room. @@ -2008,7 +2005,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_id: str, before_limit: int, after_limit: int, - event_filter: Optional[Filter], + event_filter: Filter | None, ) -> dict: """Retrieves event_ids and pagination tokens around a given event in a room. @@ -2073,7 +2070,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_id: int, current_id: int, limit: int, - ) -> tuple[int, dict[str, Optional[int]]]: + ) -> tuple[int, dict[str, int | None]]: """Get all new events Returns all event ids with from_id < stream_ordering <= current_id. @@ -2094,7 +2091,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def get_all_new_event_ids_stream_txn( txn: LoggingTransaction, - ) -> tuple[int, dict[str, Optional[int]]]: + ) -> tuple[int, dict[str, int | None]]: sql = ( "SELECT e.stream_ordering, e.event_id, e.received_ts" " FROM events AS e" @@ -2111,7 +2108,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if len(rows) == limit: upper_bound = rows[-1][0] - event_to_received_ts: dict[str, Optional[int]] = { + event_to_received_ts: dict[str, int | None] = { row[1]: row[2] for row in rows } return upper_bound, event_to_received_ts @@ -2221,10 +2218,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn: LoggingTransaction, room_id: str, from_token: RoomStreamToken, - to_token: Optional[RoomStreamToken] = None, + to_token: RoomStreamToken | None = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, - event_filter: Optional[Filter] = None, + event_filter: Filter | None = None, ) -> tuple[list[_EventDictReturn], RoomStreamToken, bool]: """Returns list of events before or after a given token. @@ -2395,10 +2392,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): *, room_id: str, from_key: RoomStreamToken, - to_key: Optional[RoomStreamToken] = None, + to_key: RoomStreamToken | None = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, - event_filter: Optional[Filter] = None, + event_filter: Filter | None = None, ) -> tuple[list[EventBase], RoomStreamToken, bool]: """ Paginate events by `topological_ordering` (tie-break with `stream_ordering`) in @@ -2525,9 +2522,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): async def get_timeline_gaps( self, room_id: str, - from_token: Optional[RoomStreamToken], + from_token: RoomStreamToken | None, to_token: RoomStreamToken, - ) -> Optional[RoomStreamToken]: + ) -> RoomStreamToken | None: """Check if there is a gap, and return a token that marks the position of the gap in the stream. """ diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py index 7410507255..05ebb57cf3 100644 --- a/synapse/storage/databases/main/task_scheduler.py +++ b/synapse/storage/databases/main/task_scheduler.py @@ -19,7 +19,7 @@ # # -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -63,11 +63,11 @@ class TaskSchedulerWorkerStore(SQLBaseStore): async def get_scheduled_tasks( self, *, - actions: Optional[list[str]] = None, - resource_id: Optional[str] = None, - statuses: Optional[list[TaskStatus]] = None, - max_timestamp: Optional[int] = None, - limit: Optional[int] = None, + actions: list[str] | None = None, + resource_id: str | None = None, + statuses: list[TaskStatus] | None = None, + max_timestamp: int | None = None, + limit: int | None = None, ) -> list[ScheduledTask]: """Get a list of scheduled tasks from the DB. @@ -152,9 +152,9 @@ class TaskSchedulerWorkerStore(SQLBaseStore): id: str, timestamp: int, *, - status: Optional[TaskStatus] = None, - result: Optional[JsonMapping] = None, - error: Optional[str] = None, + status: TaskStatus | None = None, + result: JsonMapping | None = None, + error: str | None = None, ) -> bool: """Update a scheduled task in the DB with some new value(s). @@ -182,7 +182,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore): ) return nb_rows > 0 - async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]: + async def get_scheduled_task(self, id: str) -> ScheduledTask | None: """Get a specific `ScheduledTask` from its id. Args: @@ -191,7 +191,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore): Returns: the task if available, `None` otherwise """ row = cast( - Optional[ScheduledTaskRow], + ScheduledTaskRow | None, await self.db_pool.simple_select_one( table="scheduled_tasks", keyvalues={"id": id}, diff --git a/synapse/storage/databases/main/thread_subscriptions.py b/synapse/storage/databases/main/thread_subscriptions.py index 1c02ab1611..e177e67ab1 100644 --- a/synapse/storage/databases/main/thread_subscriptions.py +++ b/synapse/storage/databases/main/thread_subscriptions.py @@ -15,8 +15,6 @@ from typing import ( TYPE_CHECKING, Any, Iterable, - Optional, - Union, cast, ) @@ -162,8 +160,8 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): room_id: str, thread_root_event_id: str, *, - automatic_event_orderings: Optional[EventOrderings], - ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: + automatic_event_orderings: EventOrderings | None, + ) -> int | AutomaticSubscriptionConflicted | None: """Updates a user's subscription settings for a specific thread root. If no change would be made to the subscription, does not produce any database change. @@ -205,7 +203,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): def _subscribe_user_to_thread_txn( txn: LoggingTransaction, - ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: + ) -> int | AutomaticSubscriptionConflicted | None: requested_automatic = automatic_event_orderings is not None row = self.db_pool.simple_select_one_txn( @@ -307,7 +305,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): async def unsubscribe_user_from_thread( self, user_id: str, room_id: str, thread_root_event_id: str - ) -> Optional[int]: + ) -> int | None: """Unsubscribes a user from a thread. If no change would be made to the subscription, does not produce any database change. @@ -323,7 +321,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): assert self._can_write_to_thread_subscriptions - def _unsubscribe_user_from_thread_txn(txn: LoggingTransaction) -> Optional[int]: + def _unsubscribe_user_from_thread_txn(txn: LoggingTransaction) -> int | None: already_subscribed = self.db_pool.simple_select_one_onecol_txn( txn, table="thread_subscriptions", @@ -420,7 +418,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): @cached(tree=True) async def get_subscription_for_thread( self, user_id: str, room_id: str, thread_root_event_id: str - ) -> Optional[ThreadSubscription]: + ) -> ThreadSubscription | None: """Get the thread subscription for a specific thread and user. Args: @@ -540,7 +538,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): async def get_latest_updated_thread_subscriptions_for_user( self, user_id: str, *, from_id: int, to_id: int, limit: int - ) -> list[tuple[int, str, str, bool, Optional[bool]]]: + ) -> list[tuple[int, str, str, bool, bool | None]]: """Get the latest updates to thread subscriptions for a specific user. Args: @@ -558,7 +556,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): def get_updated_thread_subscriptions_for_user_txn( txn: LoggingTransaction, - ) -> list[tuple[int, str, str, bool, Optional[bool]]]: + ) -> list[tuple[int, str, str, bool, bool | None]]: sql = """ WITH the_updates AS ( SELECT stream_id, room_id, event_id, subscribed, automatic diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index e0422f7459..70c5b928fd 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -21,7 +21,7 @@ import logging from enum import Enum -from typing import TYPE_CHECKING, Iterable, Mapping, Optional, cast +from typing import TYPE_CHECKING, Iterable, Mapping, cast import attr from canonicaljson import encode_canonical_json @@ -97,7 +97,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): async def get_received_txn_response( self, transaction_id: str, origin: str - ) -> Optional[tuple[int, JsonDict]]: + ) -> tuple[int, JsonDict] | None: """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response body (as a dict). @@ -120,7 +120,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): def _get_received_txn_response( self, txn: LoggingTransaction, transaction_id: str, origin: str - ) -> Optional[tuple[int, JsonDict]]: + ) -> tuple[int, JsonDict] | None: result = self.db_pool.simple_select_one_txn( txn, table="received_transactions", @@ -169,7 +169,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): async def get_destination_retry_timings( self, destination: str, - ) -> Optional[DestinationRetryTimings]: + ) -> DestinationRetryTimings | None: """Gets the current retry timings (if any) for a given destination. Args: @@ -190,7 +190,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): def _get_destination_retry_timings( self, txn: LoggingTransaction, destination: str - ) -> Optional[DestinationRetryTimings]: + ) -> DestinationRetryTimings | None: result = self.db_pool.simple_select_one_txn( txn, table="destinations", @@ -213,9 +213,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) async def get_destination_retry_timings_batch( self, destinations: StrCollection - ) -> Mapping[str, Optional[DestinationRetryTimings]]: + ) -> Mapping[str, DestinationRetryTimings | None]: rows = cast( - list[tuple[str, Optional[int], Optional[int], Optional[int]]], + list[tuple[str, int | None, int | None, int | None]], await self.db_pool.simple_select_many_batch( table="destinations", iterable=destinations, @@ -241,7 +241,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): async def set_destination_retry_timings( self, destination: str, - failure_ts: Optional[int], + failure_ts: int | None, retry_last_ts: int, retry_interval: int, ) -> None: @@ -269,7 +269,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): self, txn: LoggingTransaction, destination: str, - failure_ts: Optional[int], + failure_ts: int | None, retry_last_ts: int, retry_interval: int, ) -> None: @@ -337,7 +337,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): async def get_destination_last_successful_stream_ordering( self, destination: str - ) -> Optional[int]: + ) -> int | None: """ Gets the stream ordering of the PDU most-recently successfully sent to the specified destination, or None if this information has not been @@ -420,7 +420,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): return event_ids async def get_catch_up_outstanding_destinations( - self, after_destination: Optional[str] + self, after_destination: str | None ) -> list[str]: """ Get a list of destinations we should retry transaction sending to. @@ -449,7 +449,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): @staticmethod def _get_catch_up_outstanding_destinations_txn( - txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str] + txn: LoggingTransaction, now_time_ms: int, after_destination: str | None ) -> list[str]: # We're looking for destinations which satisfy either of the following # conditions: @@ -537,11 +537,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): self, start: int, limit: int, - destination: Optional[str] = None, + destination: str | None = None, order_by: str = DestinationSortOrder.DESTINATION.value, direction: Direction = Direction.FORWARDS, ) -> tuple[ - list[tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]], + list[tuple[str, int | None, int | None, int | None, int | None]], int, ]: """Function to retrieve a paginated list of destinations. @@ -567,9 +567,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): def get_destinations_paginate_txn( txn: LoggingTransaction, ) -> tuple[ - list[ - tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]] - ], + list[tuple[str, int | None, int | None, int | None, int | None]], int, ]: order_by_column = DestinationSortOrder(order_by).value @@ -599,11 +597,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): """ txn.execute(sql, args + [limit, start]) destinations = cast( - list[ - tuple[ - str, Optional[int], Optional[int], Optional[int], Optional[int] - ] - ], + list[tuple[str, int | None, int | None, int | None, int | None]], txn.fetchall(), ) return destinations, count diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 69a4431f29..e523f0238a 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Any, Optional, Union, cast +from typing import Any, cast import attr @@ -142,7 +142,7 @@ class UIAuthWorkerStore(SQLBaseStore): self, session_id: str, stage_type: str, - result: Union[str, bool, JsonDict], + result: str | bool | JsonDict, ) -> None: """ Mark a session stage as completed. @@ -170,7 +170,7 @@ class UIAuthWorkerStore(SQLBaseStore): async def get_completed_ui_auth_stages( self, session_id: str - ) -> dict[str, Union[str, bool, JsonDict]]: + ) -> dict[str, str | bool | JsonDict]: """ Retrieve the completed stages of a UI authentication session. @@ -262,7 +262,7 @@ class UIAuthWorkerStore(SQLBaseStore): ) async def get_ui_auth_session_data( - self, session_id: str, key: str, default: Optional[Any] = None + self, session_id: str, key: str, default: Any | None = None ) -> Any: """ Retrieve data stored with set_session_data diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 7a57beee71..6c5abc71ae 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -26,7 +26,6 @@ from typing import ( TYPE_CHECKING, Collection, Iterable, - Optional, Sequence, TypedDict, cast, @@ -72,8 +71,8 @@ class _UserDirProfile: user_id: str # If the display name or avatar URL are unexpected types, replace with None - display_name: Optional[str] = attr.ib(default=None, converter=non_null_str_or_none) - avatar_url: Optional[str] = attr.ib(default=None, converter=non_null_str_or_none) + display_name: str | None = attr.ib(default=None, converter=non_null_str_or_none) + avatar_url: str | None = attr.ib(default=None, converter=non_null_str_or_none) class UserDirectoryBackgroundUpdateStore(StateDeltasStore): @@ -206,7 +205,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): def _get_next_batch( txn: LoggingTransaction, - ) -> Optional[Sequence[tuple[str, int]]]: + ) -> Sequence[tuple[str, int]] | None: # Only fetch 250 rooms, so we don't fetch too many at once, even # if those 250 rooms have less than batch_size state events. sql = """ @@ -352,7 +351,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): def _populate_user_directory_process_users_txn( txn: LoggingTransaction, - ) -> Optional[int]: + ) -> int | None: # Note: we use an ORDER BY in the SELECT to force usage of an # index. Otherwise, postgres does a sequential scan that is # surprisingly slow (I think due to the fact it will read/skip @@ -397,7 +396,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # Next fetch their profiles. Note that not all users have profiles. profile_rows = cast( - list[tuple[str, Optional[str], Optional[str]]], + list[tuple[str, str | None, str | None]], self.db_pool.simple_select_many_txn( txn, table="profiles", @@ -492,7 +491,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ] rows = cast( - list[tuple[str, Optional[str]]], + list[tuple[str, str | None]], self.db_pool.simple_select_many_txn( txn, table="users", @@ -646,7 +645,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) async def update_profile_in_user_dir( - self, user_id: str, display_name: Optional[str], avatar_url: Optional[str] + self, user_id: str, display_name: str | None, avatar_url: str | None ) -> None: """ Update or add a user's profile in the user directory. @@ -812,7 +811,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): async def _get_user_in_directory( self, user_id: str - ) -> Optional[tuple[Optional[str], Optional[str]]]: + ) -> tuple[str | None, str | None] | None: """ Fetch the user information in the user directory. @@ -821,7 +820,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): avatar URL (both of which may be None). """ return cast( - Optional[tuple[Optional[str], Optional[str]]], + tuple[str | None, str | None] | None, await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, @@ -831,7 +830,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ), ) - async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None: + async def update_user_directory_stream_pos(self, stream_id: int | None) -> None: await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, @@ -971,7 +970,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_user_directory_stream_pos(self) -> Optional[int]: + async def get_user_directory_stream_pos(self) -> int | None: """ Get the stream ID of the user directory stream. @@ -1144,7 +1143,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): raise Exception("Unrecognized database engine") results = cast( - list[tuple[str, Optional[str], Optional[str]]], + list[tuple[str, str | None, str | None]], await self.db_pool.execute("search_user_dir", sql, *args), ) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index a0d8667b07..8c505041f0 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -23,8 +23,6 @@ import logging from typing import ( TYPE_CHECKING, Mapping, - Optional, - Union, ) from synapse.logging.opentracing import tag_args, trace @@ -82,7 +80,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): else: # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group: Optional[int] = state_group + next_group: int | None = state_group count = 0 while next_group: @@ -104,7 +102,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): self, txn: LoggingTransaction, groups: list[int], - state_filter: Optional[StateFilter] = None, + state_filter: StateFilter | None = None, ) -> Mapping[int, StateMap[str]]: """ Given a number of state groups, fetch the latest state for each group. @@ -144,7 +142,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): %s """ - overall_select_query_args: list[Union[int, str]] = [] + overall_select_query_args: list[int | str] = [] # This is an optimization to create a select clause per-condition. This # makes the query planner a lot smarter on what rows should pull out in the @@ -153,7 +151,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): use_condition_optimization = ( not state_filter.include_others and not state_filter.is_full() ) - state_filter_condition_combos: list[tuple[str, Optional[str]]] = [] + state_filter_condition_combos: list[tuple[str, str | None]] = [] # We don't need to caclculate this list if we're not using the condition # optimization if use_condition_optimization: @@ -213,7 +211,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): """ for group in groups: - args: list[Union[int, str]] = [group] + args: list[int | str] = [group] args.extend(overall_select_query_args) txn.execute(sql % (overall_select_clause,), args) @@ -235,7 +233,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): # # We just haven't put in the time to refactor this. for group in groups: - next_group: Optional[int] = group + next_group: int | None = group while next_group: # We did this before by getting the list of group ids, and diff --git a/synapse/storage/databases/state/deletion.py b/synapse/storage/databases/state/deletion.py index 6975690c51..23150e8626 100644 --- a/synapse/storage/databases/state/deletion.py +++ b/synapse/storage/databases/state/deletion.py @@ -20,7 +20,6 @@ from typing import ( AsyncIterator, Collection, Mapping, - Optional, ) from synapse.events.snapshot import EventPersistencePair @@ -506,7 +505,7 @@ class StateDeletionDataStore: async def get_next_state_group_collection_to_delete( self, - ) -> Optional[tuple[str, Mapping[int, int]]]: + ) -> tuple[str, Mapping[int, int]] | None: """Get the next set of state groups to try and delete Returns: @@ -520,7 +519,7 @@ class StateDeletionDataStore: def _get_next_state_group_collection_to_delete_txn( self, txn: LoggingTransaction, - ) -> Optional[tuple[str, Mapping[int, int]]]: + ) -> tuple[str, Mapping[int, int]] | None: """Implementation of `get_next_state_group_collection_to_delete`""" # We want to return chunks of state groups that were marked for deletion diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 6f25e7f0bc..d3ce7a8b55 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -24,7 +24,6 @@ from typing import ( TYPE_CHECKING, Iterable, Mapping, - Optional, cast, ) @@ -69,8 +68,8 @@ class _GetStateGroupDelta: us use the iterable flag when caching """ - prev_group: Optional[int] - delta_ids: Optional[StateMap[str]] + prev_group: int | None + delta_ids: StateMap[str] | None def __len__(self) -> int: return len(self.delta_ids) if self.delta_ids else 0 @@ -279,7 +278,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): @tag_args @cancellable async def _get_state_for_groups( - self, groups: Iterable[int], state_filter: Optional[StateFilter] = None + self, groups: Iterable[int], state_filter: StateFilter | None = None ) -> dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -571,9 +570,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self, event_id: str, room_id: str, - prev_group: Optional[int], - delta_ids: Optional[StateMap[str]], - current_state_ids: Optional[StateMap[str]], + prev_group: int | None, + delta_ids: StateMap[str] | None, + current_state_ids: StateMap[str] | None, ) -> int: """Store a new set of state, returning a newly assigned state group. @@ -602,7 +601,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): def insert_delta_group_txn( txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str] - ) -> Optional[int]: + ) -> int | None: """Try and persist the new group as a delta. Requires that we have the state as a delta from a previous state group. diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index be6981f77c..026b742aad 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -20,7 +20,7 @@ # import abc from enum import IntEnum -from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Mapping, TypeVar from synapse.storage.types import Connection, Cursor, DBAPI2Module @@ -123,7 +123,7 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM @abc.abstractmethod def attempt_to_set_isolation_level( - self, conn: ConnectionType, isolation_level: Optional[int] + self, conn: ConnectionType, isolation_level: int | None ) -> None: """Attempt to set the connections isolation level. diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index b059b924c2..cc7e5508fd 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, cast +from typing import TYPE_CHECKING, Any, Mapping, NoReturn, cast import psycopg2.extensions @@ -60,10 +60,10 @@ class PostgresEngine( # some degenerate query plan has been created and the client has probably # timed out/walked off anyway. # This is in milliseconds. - self.statement_timeout: Optional[int] = database_config.get( + self.statement_timeout: int | None = database_config.get( "statement_timeout", 60 * 60 * 1000 ) - self._version: Optional[int] = None # unknown as yet + self._version: int | None = None # unknown as yet self.isolation_level_map: Mapping[int, int] = { IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED, @@ -234,7 +234,7 @@ class PostgresEngine( return conn.set_session(autocommit=autocommit) def attempt_to_set_isolation_level( - self, conn: psycopg2.extensions.connection, isolation_level: Optional[int] + self, conn: psycopg2.extensions.connection, isolation_level: int | None ) -> None: if isolation_level is None: isolation_level = self.default_isolation_level diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index b49d230eed..3b1b19c00e 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -22,7 +22,7 @@ import platform import sqlite3 import struct import threading -from typing import TYPE_CHECKING, Any, Mapping, Optional +from typing import TYPE_CHECKING, Any, Mapping from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines._base import AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER @@ -45,7 +45,7 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): # A connection to a database that has already been prepared, to use as a # base for an in-memory connection. This is used during unit tests to # speed up setting up the DB. - self._prepped_conn: Optional[sqlite3.Connection] = database_config.get( + self._prepped_conn: sqlite3.Connection | None = database_config.get( "_TEST_PREPPED_CONN" ) @@ -141,7 +141,7 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): pass def attempt_to_set_isolation_level( - self, conn: sqlite3.Connection, isolation_level: Optional[int] + self, conn: sqlite3.Connection, isolation_level: int | None ) -> None: # All transactions are SERIALIZABLE by default in sqlite pass diff --git a/synapse/storage/invite_rule.py b/synapse/storage/invite_rule.py index f63390871e..3de77e8c21 100644 --- a/synapse/storage/invite_rule.py +++ b/synapse/storage/invite_rule.py @@ -1,6 +1,6 @@ import logging from enum import Enum -from typing import Optional, Pattern +from typing import Pattern from matrix_common.regex import glob_to_regex @@ -20,7 +20,7 @@ class InviteRule(Enum): class InviteRulesConfig: """Class to determine if a given user permits an invite from another user, and the action to take.""" - def __init__(self, account_data: Optional[JsonMapping]): + def __init__(self, account_data: JsonMapping | None): self.allowed_users: list[Pattern[str]] = [] self.ignored_users: list[Pattern[str]] = [] self.blocked_users: list[Pattern[str]] = [] @@ -30,7 +30,7 @@ class InviteRulesConfig: self.blocked_servers: list[Pattern[str]] = [] def process_field( - values: Optional[list[str]], + values: list[str] | None, ruleset: list[Pattern[str]], rule: InviteRule, ) -> None: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index d4bd8020e1..2def1e130c 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -28,7 +28,6 @@ from typing import ( Counter as CounterType, Generator, Iterable, - Optional, TextIO, ) @@ -75,7 +74,7 @@ class _SchemaState: current_version: int = attr.ib() """The current schema version of the database""" - compat_version: Optional[int] = attr.ib() + compat_version: int | None = attr.ib() """The SCHEMA_VERSION of the oldest version of Synapse for this database If this is None, we have an old version of the database without the necessary @@ -95,7 +94,7 @@ class _SchemaState: def prepare_database( db_conn: LoggingDatabaseConnection, database_engine: BaseDatabaseEngine, - config: Optional[HomeServerConfig], + config: HomeServerConfig | None, databases: Collection[str] = ("main", "state"), ) -> None: """Prepares a physical database for usage. Will either create all necessary tables @@ -307,7 +306,7 @@ def _upgrade_existing_database( cur: LoggingTransaction, current_schema_state: _SchemaState, database_engine: BaseDatabaseEngine, - config: Optional[HomeServerConfig], + config: HomeServerConfig | None, databases: Collection[str], is_empty: bool = False, ) -> None: @@ -683,7 +682,7 @@ def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None: def _get_or_create_schema_state( txn: Cursor, database_engine: BaseDatabaseEngine -) -> Optional[_SchemaState]: +) -> _SchemaState | None: # Bluntly try creating the schema_version tables. sql_path = os.path.join(schema_path, "common", "schema_version.sql") database_engine.execute_script_file(txn, sql_path) @@ -698,7 +697,7 @@ def _get_or_create_schema_state( current_version = int(row[0]) upgraded = bool(row[1]) - compat_version: Optional[int] = None + compat_version: int | None = None txn.execute("SELECT compat_version FROM schema_compat_version") row = txn.fetchone() if row is not None: diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 35da5351f8..4c1ace28e7 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -20,7 +20,6 @@ # import logging -from typing import Optional import attr @@ -42,14 +41,14 @@ class RoomsForUser: @attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) class RoomsForUserSlidingSync: room_id: str - sender: Optional[str] + sender: str | None membership: str - event_id: Optional[str] + event_id: str | None event_pos: PersistedEventPosition room_version_id: str has_known_state: bool - room_type: Optional[str] + room_type: str | None is_encrypted: bool @@ -60,9 +59,9 @@ class RoomsForUserStateReset: without a corresponding event so that information isn't always available.""" room_id: str - sender: Optional[str] + sender: str | None membership: str - event_id: Optional[str] + event_id: str | None event_pos: PersistedEventPosition room_version_id: str @@ -75,8 +74,8 @@ class GetRoomsForUserWithStreamOrdering: @attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) class ProfileInfo: - avatar_url: Optional[str] - display_name: Optional[str] + avatar_url: str | None + display_name: str | None # TODO This is used as a cached value and is mutable. diff --git a/synapse/storage/types.py b/synapse/storage/types.py index fedf10dfc0..ad9e5391e3 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -24,17 +24,15 @@ from typing import ( Callable, Iterator, Mapping, - Optional, Protocol, Sequence, - Union, ) """ Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ -SQLQueryParameters = Union[Sequence[Any], Mapping[str, Any]] +SQLQueryParameters = Sequence[Any] | Mapping[str, Any] class Cursor(Protocol): @@ -44,16 +42,16 @@ class Cursor(Protocol): self, sql: str, parameters: Sequence[SQLQueryParameters] ) -> Any: ... - def fetchone(self) -> Optional[tuple]: ... + def fetchone(self) -> tuple | None: ... - def fetchmany(self, size: Optional[int] = ...) -> list[tuple]: ... + def fetchmany(self, size: int | None = ...) -> list[tuple]: ... def fetchall(self) -> list[tuple]: ... @property def description( self, - ) -> Optional[Sequence[Any]]: + ) -> Sequence[Any] | None: # At the time of writing, Synapse only assumes that `column[0]: str` for each # `column in description`. Since this is hard to express in the type system, and # as this is rarely used in Synapse, we deem `column: Any` good enough. @@ -81,10 +79,10 @@ class Connection(Protocol): def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> Optional[bool]: ... + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: ... class DBAPI2Module(Protocol): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5bf5c2b4bf..66c993cbd9 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -30,10 +30,8 @@ from typing import ( ContextManager, Generic, Iterable, - Optional, Sequence, TypeVar, - Union, cast, ) @@ -619,7 +617,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids.difference_update(next_ids) self._finished_ids.update(next_ids) - new_cur: Optional[int] = None + new_cur: int | None = None if self._unfinished_ids or self._in_flight_fetches: # If there are unfinished IDs then the new position will be the @@ -844,10 +842,10 @@ class _AsyncCtxManagerWrapper(Generic[T]): async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], - ) -> Optional[bool]: + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool | None: return self.inner.__exit__(exc_type, exc, tb) @@ -857,10 +855,10 @@ class _MultiWriterCtxManager: id_gen: MultiWriterIdGenerator notifier: "ReplicationNotifier" - multiple_ids: Optional[int] = None + multiple_ids: int | None = None stream_ids: list[int] = attr.Factory(list) - async def __aenter__(self) -> Union[int, list[int]]: + async def __aenter__(self) -> int | list[int]: # It's safe to run this in autocommit mode as fetching values from a # sequence ignores transaction semantics anyway. self.stream_ids = await self.id_gen._db.runInteraction( @@ -877,9 +875,9 @@ class _MultiWriterCtxManager: async def __aexit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> bool: self.id_gen._mark_ids_as_finished(self.stream_ids) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index e2256aa109..5bee3cf34f 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -21,7 +21,7 @@ import abc import logging import threading -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable from synapse.storage.engines import ( BaseDatabaseEngine, @@ -71,7 +71,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta): db_conn: "LoggingDatabaseConnection", table: str, id_column: str, - stream_name: Optional[str] = None, + stream_name: str | None = None, positive: bool = True, ) -> None: """Should be called during start up to test that the current value of @@ -116,7 +116,7 @@ class PostgresSequenceGenerator(SequenceGenerator): db_conn: "LoggingDatabaseConnection", table: str, id_column: str, - stream_name: Optional[str] = None, + stream_name: str | None = None, positive: bool = True, ) -> None: """See SequenceGenerator.check_consistency for docstring.""" @@ -223,10 +223,10 @@ class LocalSequenceGenerator(SequenceGenerator): get_next_id_txn; should return the current maximum id """ # the callback. this is cleared after it is called, so that it can be GCed. - self._callback: Optional[GetFirstCallbackType] = get_first_callback + self._callback: GetFirstCallbackType | None = get_first_callback # The current max value, or None if we haven't looked in the DB yet. - self._current_max_id: Optional[int] = None + self._current_max_id: int | None = None self._lock = threading.Lock() def get_next_id_txn(self, txn: Cursor) -> int: @@ -257,7 +257,7 @@ class LocalSequenceGenerator(SequenceGenerator): db_conn: Connection, table: str, id_column: str, - stream_name: Optional[str] = None, + stream_name: str | None = None, positive: bool = True, ) -> None: # There is nothing to do for in memory sequences @@ -278,9 +278,9 @@ def build_sequence_generator( database_engine: BaseDatabaseEngine, get_first_callback: GetFirstCallbackType, sequence_name: str, - table: Optional[str], - id_column: Optional[str], - stream_name: Optional[str] = None, + table: str | None, + id_column: str | None, + stream_name: str | None = None, positive: bool = True, ) -> SequenceGenerator: """Get the best impl of SequenceGenerator available diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index faf453b8a1..0d386e538e 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -19,7 +19,7 @@ # # from abc import ABC, abstractmethod -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar from synapse.types import StrCollection, UserID @@ -38,6 +38,6 @@ class EventSource(ABC, Generic[K, R]): limit: int, room_ids: StrCollection, is_guest: bool, - explicit_room_id: Optional[str] = None, + explicit_room_id: str | None = None, ) -> tuple[list[R], K]: raise NotImplementedError() diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 9fee5bfb92..52688a8b6b 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -19,7 +19,6 @@ # # import logging -from typing import Optional import attr @@ -40,8 +39,8 @@ MAX_LIMIT = 1000 class PaginationConfig: """A configuration object which stores pagination parameters.""" - from_token: Optional[StreamToken] - to_token: Optional[StreamToken] + from_token: StreamToken | None + to_token: StreamToken | None direction: Direction limit: int diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 08c976121a..0add391c65 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -10,16 +10,16 @@ # See the GNU Affero General Public License for more details: # . -from typing import Mapping, Optional +from typing import Mapping from synapse.types import JsonDict class EventInternalMetadata: def __init__(self, internal_metadata_dict: JsonDict): ... - stream_ordering: Optional[int] + stream_ordering: int | None """the stream ordering of this event. None, until it has been persisted.""" - instance_name: Optional[str] + instance_name: str | None """the instance name of the server that persisted this event. None, until it has been persisted.""" outlier: bool @@ -62,7 +62,7 @@ class EventInternalMetadata: (Added in synapse 0.99.0, so may be unreliable for events received before that) """ - def get_send_on_behalf_of(self) -> Optional[str]: + def get_send_on_behalf_of(self) -> str | None: """Whether this server should send the event on behalf of another server. This is used by the federation "send_join" API to forward the initial join event for a server in the room. diff --git a/synapse/synapse_rust/push.pyi b/synapse/synapse_rust/push.pyi index 1e135b8c69..9d8f0389e8 100644 --- a/synapse/synapse_rust/push.pyi +++ b/synapse/synapse_rust/push.pyi @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection, Mapping, Optional, Sequence, Union +from typing import Any, Collection, Mapping, Sequence from synapse.types import JsonDict, JsonValue @@ -25,7 +25,7 @@ class PushRule: @property def conditions(self) -> Sequence[Mapping[str, str]]: ... @property - def actions(self) -> Sequence[Union[Mapping[str, Any], str]]: ... + def actions(self) -> Sequence[Mapping[str, Any] | str]: ... @property def default(self) -> bool: ... @property @@ -61,7 +61,7 @@ class PushRuleEvaluator: flattened_keys: Mapping[str, JsonValue], has_mentions: bool, room_member_count: int, - sender_power_level: Optional[int], + sender_power_level: int | None, notification_power_levels: Mapping[str, int], related_events_flattened: Mapping[str, Mapping[str, JsonValue]], related_event_match_enabled: bool, @@ -73,14 +73,14 @@ class PushRuleEvaluator: def run( self, push_rules: FilteredPushRules, - user_id: Optional[str], - display_name: Optional[str], - msc4306_thread_subscription_state: Optional[bool], - ) -> Collection[Union[Mapping, str]]: ... + user_id: str | None, + display_name: str | None, + msc4306_thread_subscription_state: bool | None, + ) -> Collection[Mapping | str]: ... def matches( self, condition: JsonDict, - user_id: Optional[str], - display_name: Optional[str], - msc4306_thread_subscription_state: Optional[bool] = None, + user_id: str | None, + display_name: str | None, + msc4306_thread_subscription_state: bool | None = None, ) -> bool: ... diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 87436459ac..16892b37c0 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -85,8 +85,8 @@ MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. # A "simple" (canonical) JSON value. -SimpleJsonValue = Optional[Union[str, int, bool]] -JsonValue = Union[list[SimpleJsonValue], tuple[SimpleJsonValue, ...], SimpleJsonValue] +SimpleJsonValue = str | int | bool | None +JsonValue = list[SimpleJsonValue] | tuple[SimpleJsonValue, ...] | SimpleJsonValue # A JSON-serialisable dict. JsonDict = dict[str, Any] # A JSON-serialisable mapping; roughly speaking an immutable JSONDict. @@ -101,12 +101,12 @@ JsonSerializable = object # # StrCollection is an unordered collection of strings. If ordering is important, # StrSequence can be used instead. -StrCollection = Union[tuple[str, ...], list[str], AbstractSet[str]] +StrCollection = tuple[str, ...] | list[str] | AbstractSet[str] # Sequence[str] that does not include str itself; str being a Sequence[str] # is very misleading and results in bugs. # # Unlike StrCollection, StrSequence is an ordered collection of strings. -StrSequence = Union[tuple[str, ...], list[str]] +StrSequence = tuple[str, ...] | list[str] # Note that this seems to require inheriting *directly* from Interface in order @@ -158,11 +158,11 @@ class Requester: """ user: "UserID" - access_token_id: Optional[int] + access_token_id: int | None is_guest: bool scope: set[str] shadow_banned: bool - device_id: Optional[str] + device_id: str | None app_service: Optional["ApplicationService"] authenticated_entity: str @@ -216,13 +216,13 @@ class Requester: def create_requester( user_id: Union[str, "UserID"], - access_token_id: Optional[int] = None, + access_token_id: int | None = None, is_guest: bool = False, scope: StrCollection = (), shadow_banned: bool = False, - device_id: Optional[str] = None, + device_id: str | None = None, app_service: Optional["ApplicationService"] = None, - authenticated_entity: Optional[str] = None, + authenticated_entity: str | None = None, ) -> Requester: """ Create a new ``Requester`` object @@ -385,7 +385,7 @@ class RoomID: SIGIL = "!" id: str - room_id_with_domain: Optional[RoomIdWithDomain] + room_id_with_domain: RoomIdWithDomain | None @classmethod def is_valid(cls: type["RoomID"], s: str) -> bool: @@ -397,7 +397,7 @@ class RoomID: except Exception: return False - def get_domain(self) -> Optional[str]: + def get_domain(self) -> str | None: if not self.room_id_with_domain: return None return self.room_id_with_domain.domain @@ -419,7 +419,7 @@ class RoomID: Codes.INVALID_PARAM, ) - room_id_with_domain: Optional[RoomIdWithDomain] = None + room_id_with_domain: RoomIdWithDomain | None = None if ":" in s: room_id_with_domain = RoomIdWithDomain.from_string(s) else: @@ -487,7 +487,7 @@ NON_MXID_CHARACTER_PATTERN = re.compile( def map_username_to_mxid_localpart( - username: Union[str, bytes], case_sensitive: bool = False + username: str | bytes, case_sensitive: bool = False ) -> str: """Map a username onto a string suitable for a MXID @@ -744,7 +744,7 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): attributes, must be hashable. """ - topological: Optional[int] = attr.ib( + topological: int | None = attr.ib( validator=attr.validators.optional(attr.validators.instance_of(int)), kw_only=True, default=None, @@ -954,7 +954,7 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken): def is_stream_position_in_range( low: Optional["AbstractMultiWriterStreamToken"], high: Optional["AbstractMultiWriterStreamToken"], - instance_name: Optional[str], + instance_name: str | None, pos: int, ) -> bool: """Checks if a given persisted position is between the two given tokens. @@ -1224,11 +1224,11 @@ class StreamToken: @overload def get_field( self, key: StreamKeyType - ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: ... + ) -> int | RoomStreamToken | MultiWriterStreamToken: ... def get_field( self, key: StreamKeyType - ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: + ) -> int | RoomStreamToken | MultiWriterStreamToken: """Returns the stream ID for the given key.""" return getattr(self, key.value) @@ -1394,8 +1394,8 @@ class PersistedEventPosition(PersistedPosition): @attr.s(slots=True, frozen=True, auto_attribs=True) class ThirdPartyInstanceID: - appservice_id: Optional[str] - network_id: Optional[str] + appservice_id: str | None + network_id: str | None # Deny iteration because it will bite you if you try to create a singleton # set by: @@ -1432,7 +1432,7 @@ class ReadReceipt: receipt_type: str user_id: str event_ids: list[str] - thread_id: Optional[str] + thread_id: str | None data: JsonDict @@ -1507,11 +1507,11 @@ class UserInfo: """ user_id: UserID - appservice_id: Optional[int] - consent_server_notice_sent: Optional[str] - consent_version: Optional[str] - consent_ts: Optional[int] - user_type: Optional[str] + appservice_id: int | None + consent_server_notice_sent: str | None + consent_version: str | None + consent_ts: int | None + user_type: str | None creation_ts: int is_admin: bool is_deactivated: bool @@ -1524,14 +1524,14 @@ class UserInfo: class UserProfile(TypedDict): user_id: str - display_name: Optional[str] - avatar_url: Optional[str] + display_name: str | None + avatar_url: str | None @attr.s(auto_attribs=True, frozen=True, slots=True) class RetentionPolicy: - min_lifetime: Optional[int] = None - max_lifetime: Optional[int] = None + min_lifetime: int | None = None + max_lifetime: int | None = None class TaskStatus(str, Enum): @@ -1563,13 +1563,13 @@ class ScheduledTask: # In milliseconds since epoch in system time timezone, usually UTC. timestamp: int # Optionally bind a task to some resource id for easy retrieval - resource_id: Optional[str] + resource_id: str | None # Optional parameters that will be passed to the function ran by the task - params: Optional[JsonMapping] + params: JsonMapping | None # Optional result that can be updated by the running task - result: Optional[JsonMapping] + result: JsonMapping | None # Optional error that should be assigned a value when the status is FAILED - error: Optional[str] + error: str | None @attr.s(auto_attribs=True, frozen=True, slots=True) diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py index 80651bb685..b9d1b41a75 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py @@ -19,7 +19,7 @@ # -from typing import Optional, TypedDict +from typing import TypedDict from synapse.api.constants import EventTypes @@ -66,10 +66,10 @@ class ShutdownRoomParams(TypedDict): even if there are still users joined to the room. """ - requester_user_id: Optional[str] - new_room_user_id: Optional[str] - new_room_name: Optional[str] - message: Optional[str] + requester_user_id: str | None + new_room_user_id: str | None + new_room_name: str | None + message: str | None block: bool purge: bool force_purge: bool @@ -90,4 +90,4 @@ class ShutdownRoomResponse(TypedDict): kicked_users: list[str] failed_to_kick_users: list[str] local_aliases: list[str] - new_room_id: Optional[str] + new_room_id: str | None diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index c83b534e00..494e3570d0 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -25,7 +25,6 @@ from typing import ( Generic, Mapping, MutableMapping, - Optional, Sequence, TypeVar, cast, @@ -166,12 +165,12 @@ class SlidingSyncResult: @attr.s(slots=True, frozen=True, auto_attribs=True) class StrippedHero: user_id: str - display_name: Optional[str] - avatar_url: Optional[str] + display_name: str | None + avatar_url: str | None - name: Optional[str] - avatar: Optional[str] - heroes: Optional[list[StrippedHero]] + name: str | None + avatar: str | None + heroes: list[StrippedHero] | None is_dm: bool initial: bool unstable_expanded_timeline: bool @@ -179,18 +178,18 @@ class SlidingSyncResult: required_state: list[EventBase] # Should be empty for invite/knock rooms with `stripped_state` timeline_events: list[EventBase] - bundled_aggregations: Optional[dict[str, "BundledAggregations"]] + bundled_aggregations: dict[str, "BundledAggregations"] | None # Optional because it's only relevant to invite/knock rooms stripped_state: list[JsonDict] # Only optional because it won't be included for invite/knock rooms with `stripped_state` - prev_batch: Optional[StreamToken] + prev_batch: StreamToken | None # Only optional because it won't be included for invite/knock rooms with `stripped_state` - limited: Optional[bool] + limited: bool | None # Only optional because it won't be included for invite/knock rooms with `stripped_state` - num_live: Optional[int] - bump_stamp: Optional[int] - joined_count: Optional[int] - invited_count: Optional[int] + num_live: int | None + bump_stamp: int | None + joined_count: int | None + invited_count: int | None notification_count: int highlight_count: int @@ -281,7 +280,7 @@ class SlidingSyncResult: """ # Only present on incremental syncs - device_list_updates: Optional[DeviceListUpdates] + device_list_updates: DeviceListUpdates | None device_one_time_keys_count: Mapping[str, int] device_unused_fallback_key_types: Sequence[str] @@ -364,7 +363,7 @@ class SlidingSyncResult: @attr.s(slots=True, frozen=True, auto_attribs=True) class ThreadSubscription: # always present when `subscribed` - automatic: Optional[bool] + automatic: bool | None # the same as our stream_id; useful for clients to resolve # race conditions locally @@ -377,10 +376,10 @@ class SlidingSyncResult: bump_stamp: int # room_id -> event_id (of thread root) -> the subscription change - subscribed: Optional[Mapping[str, Mapping[str, ThreadSubscription]]] + subscribed: Mapping[str, Mapping[str, ThreadSubscription]] | None # room_id -> event_id (of thread root) -> the unsubscription - unsubscribed: Optional[Mapping[str, Mapping[str, ThreadUnsubscription]]] - prev_batch: Optional[ThreadSubscriptionsToken] + unsubscribed: Mapping[str, Mapping[str, ThreadUnsubscription]] | None + prev_batch: ThreadSubscriptionsToken | None def __bool__(self) -> bool: return ( @@ -389,12 +388,12 @@ class SlidingSyncResult: or bool(self.prev_batch) ) - to_device: Optional[ToDeviceExtension] = None - e2ee: Optional[E2eeExtension] = None - account_data: Optional[AccountDataExtension] = None - receipts: Optional[ReceiptsExtension] = None - typing: Optional[TypingExtension] = None - thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None + to_device: ToDeviceExtension | None = None + e2ee: E2eeExtension | None = None + account_data: AccountDataExtension | None = None + receipts: ReceiptsExtension | None = None + typing: TypingExtension | None = None + thread_subscriptions: ThreadSubscriptionsExtension | None = None def __bool__(self) -> bool: return bool( @@ -730,7 +729,7 @@ class HaveSentRoom(Generic[T]): """ status: HaveSentRoomFlag - last_token: Optional[T] + last_token: T | None @staticmethod def live() -> "HaveSentRoom[T]": diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index 865c2ba532..49782b5234 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -18,8 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional, Union - from pydantic import ( ConfigDict, Field, @@ -49,8 +47,8 @@ class AuthenticationData(RequestBodyModel): model_config = ConfigDict(extra="allow") - session: Optional[StrictStr] = None - type: Optional[StrictStr] = None + session: StrictStr | None = None + type: StrictStr | None = None # See also assert_valid_client_secret() @@ -67,9 +65,9 @@ ClientSecretStr = Annotated[ class ThreepidRequestTokenBody(RequestBodyModel): client_secret: ClientSecretStr - id_server: Optional[StrictStr] = None - id_access_token: Optional[StrictStr] = None - next_link: Optional[StrictStr] = None + id_server: StrictStr | None = None + id_access_token: StrictStr | None = None + next_link: StrictStr | None = None send_attempt: StrictInt @model_validator(mode="after") @@ -246,17 +244,17 @@ class SlidingSyncBody(RequestBodyModel): list of favourite rooms again. """ - is_dm: Optional[StrictBool] = None - spaces: Optional[list[StrictStr]] = None - is_encrypted: Optional[StrictBool] = None - is_invite: Optional[StrictBool] = None - room_types: Optional[list[Union[StrictStr, None]]] = None - not_room_types: Optional[list[Union[StrictStr, None]]] = None - room_name_like: Optional[StrictStr] = None - tags: Optional[list[StrictStr]] = None - not_tags: Optional[list[StrictStr]] = None + is_dm: StrictBool | None = None + spaces: list[StrictStr] | None = None + is_encrypted: StrictBool | None = None + is_invite: StrictBool | None = None + room_types: list[StrictStr | None] | None = None + not_room_types: list[StrictStr | None] | None = None + room_name_like: StrictStr | None = None + tags: list[StrictStr] | None = None + not_tags: list[StrictStr] | None = None - ranges: Optional[ + ranges: ( list[ Annotated[ tuple[ @@ -266,9 +264,10 @@ class SlidingSyncBody(RequestBodyModel): Field(strict=False), ] ] - ] = None - slow_get_all_rooms: Optional[StrictBool] = False - filters: Optional[Filters] = None + | None + ) = None + slow_get_all_rooms: StrictBool | None = False + filters: Filters | None = None class RoomSubscription(CommonRoomParameters): pass @@ -291,15 +290,13 @@ class SlidingSyncBody(RequestBodyModel): since: The `next_batch` from the previous sync response """ - enabled: Optional[StrictBool] = False + enabled: StrictBool | None = False limit: StrictInt = 100 - since: Optional[StrictStr] = None + since: StrictStr | None = None @field_validator("since") @classmethod - def since_token_check( - cls, value: Optional[StrictStr] - ) -> Optional[StrictStr]: + def since_token_check(cls, value: StrictStr | None) -> StrictStr | None: # `since` comes in as an opaque string token but we know that it's just # an integer representing the position in the device inbox stream. We # want to pre-validate it to make sure it works fine in downstream code. @@ -322,7 +319,7 @@ class SlidingSyncBody(RequestBodyModel): enabled """ - enabled: Optional[StrictBool] = False + enabled: StrictBool | None = False class AccountDataExtension(RequestBodyModel): """The Account Data extension (MSC3959) @@ -335,11 +332,11 @@ class SlidingSyncBody(RequestBodyModel): extension to. """ - enabled: Optional[StrictBool] = False + enabled: StrictBool | None = False # Process all lists defined in the Sliding Window API. (This is the default.) - lists: Optional[list[StrictStr]] = ["*"] + lists: list[StrictStr] | None = ["*"] # Process all room subscriptions defined in the Room Subscription API. (This is the default.) - rooms: Optional[list[StrictStr]] = ["*"] + rooms: list[StrictStr] | None = ["*"] class ReceiptsExtension(RequestBodyModel): """The Receipts extension (MSC3960) @@ -352,11 +349,11 @@ class SlidingSyncBody(RequestBodyModel): extension to. """ - enabled: Optional[StrictBool] = False + enabled: StrictBool | None = False # Process all lists defined in the Sliding Window API. (This is the default.) - lists: Optional[list[StrictStr]] = ["*"] + lists: list[StrictStr] | None = ["*"] # Process all room subscriptions defined in the Room Subscription API. (This is the default.) - rooms: Optional[list[StrictStr]] = ["*"] + rooms: list[StrictStr] | None = ["*"] class TypingExtension(RequestBodyModel): """The Typing Notification extension (MSC3961) @@ -369,11 +366,11 @@ class SlidingSyncBody(RequestBodyModel): extension to. """ - enabled: Optional[StrictBool] = False + enabled: StrictBool | None = False # Process all lists defined in the Sliding Window API. (This is the default.) - lists: Optional[list[StrictStr]] = ["*"] + lists: list[StrictStr] | None = ["*"] # Process all room subscriptions defined in the Room Subscription API. (This is the default.) - rooms: Optional[list[StrictStr]] = ["*"] + rooms: list[StrictStr] | None = ["*"] class ThreadSubscriptionsExtension(RequestBodyModel): """The Thread Subscriptions extension (MSC4308) @@ -383,33 +380,34 @@ class SlidingSyncBody(RequestBodyModel): limit: maximum number of subscription changes to return (default 100) """ - enabled: Optional[StrictBool] = False + enabled: StrictBool | None = False limit: StrictInt = 100 - to_device: Optional[ToDeviceExtension] = None - e2ee: Optional[E2eeExtension] = None - account_data: Optional[AccountDataExtension] = None - receipts: Optional[ReceiptsExtension] = None - typing: Optional[TypingExtension] = None - thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field( + to_device: ToDeviceExtension | None = None + e2ee: E2eeExtension | None = None + account_data: AccountDataExtension | None = None + receipts: ReceiptsExtension | None = None + typing: TypingExtension | None = None + thread_subscriptions: ThreadSubscriptionsExtension | None = Field( None, alias="io.element.msc4308.thread_subscriptions" ) - conn_id: Optional[StrictStr] = None - lists: Optional[ + conn_id: StrictStr | None = None + lists: ( dict[ Annotated[str, StringConstraints(max_length=64, strict=True)], SlidingSyncList, ] - ] = None - room_subscriptions: Optional[dict[StrictStr, RoomSubscription]] = None - extensions: Optional[Extensions] = None + | None + ) = None + room_subscriptions: dict[StrictStr, RoomSubscription] | None = None + extensions: Extensions | None = None @field_validator("lists") @classmethod def lists_length_check( - cls, value: Optional[dict[str, SlidingSyncList]] - ) -> Optional[dict[str, SlidingSyncList]]: + cls, value: dict[str, SlidingSyncList] | None + ) -> dict[str, SlidingSyncList] | None: if value is not None: assert len(value) <= 100, f"Max lists: 100 but saw {len(value)}" return value diff --git a/synapse/types/state.py b/synapse/types/state.py index 1b4de61d3e..ab619a7fb8 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -27,7 +27,6 @@ from typing import ( Collection, Iterable, Mapping, - Optional, TypeVar, ) @@ -60,7 +59,7 @@ class StateFilter: appear in `types`. """ - types: "immutabledict[str, Optional[frozenset[str]]]" + types: "immutabledict[str, frozenset[str] | None]" include_others: bool = False def __attrs_post_init__(self) -> None: @@ -101,7 +100,7 @@ class StateFilter: return _NONE_STATE_FILTER @staticmethod - def from_types(types: Iterable[tuple[str, Optional[str]]]) -> "StateFilter": + def from_types(types: Iterable[tuple[str, str | None]]) -> "StateFilter": """Creates a filter that only fetches the given types Args: @@ -111,7 +110,7 @@ class StateFilter: Returns: The new state filter. """ - type_dict: dict[str, Optional[set[str]]] = {} + type_dict: dict[str, set[str] | None] = {} for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -130,7 +129,7 @@ class StateFilter: ) ) - def to_types(self) -> Iterable[tuple[str, Optional[str]]]: + def to_types(self) -> Iterable[tuple[str, str | None]]: """The inverse to `from_types`.""" for event_type, state_keys in self.types.items(): if state_keys is None: @@ -157,13 +156,13 @@ class StateFilter: @staticmethod def freeze( - types: Mapping[str, Optional[Collection[str]]], include_others: bool + types: Mapping[str, Collection[str] | None], include_others: bool ) -> "StateFilter": """ Returns a (frozen) StateFilter with the same contents as the parameters specified here, which can be made of mutable types. """ - types_with_frozen_values: dict[str, Optional[frozenset[str]]] = {} + types_with_frozen_values: dict[str, frozenset[str] | None] = {} for state_types, state_keys in types.items(): if state_keys is not None: types_with_frozen_values[state_types] = frozenset(state_keys) @@ -289,7 +288,7 @@ class StateFilter: return where_clause, where_args - def max_entries_returned(self) -> Optional[int]: + def max_entries_returned(self) -> int | None: """Returns the maximum number of entries this filter will return if known, otherwise returns None. @@ -450,7 +449,7 @@ class StateFilter: # {state type -> set of state keys OR None for wildcard} # (The same structure as that of a StateFilter.) - new_types: dict[str, Optional[set[str]]] = {} + new_types: dict[str, set[str] | None] = {} # if we start with all, insert the excluded statetypes as empty sets # to prevent them from being included diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 0d3b7ca740..f937080f9e 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -25,7 +25,6 @@ import typing from typing import ( Iterator, Mapping, - Optional, Sequence, TypeVar, ) @@ -61,7 +60,7 @@ def unwrapFirstError(failure: Failure) -> Failure: def log_failure( failure: Failure, msg: str, consumeErrors: bool = True -) -> Optional[Failure]: +) -> Failure | None: """Creates a function suitable for passing to `Deferred.addErrback` that logs any failures that occur. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 8322a1bb33..825fb10acf 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -40,9 +40,7 @@ from typing import ( Hashable, Iterable, Literal, - Optional, TypeVar, - Union, overload, ) @@ -104,8 +102,8 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): __slots__ = ["_deferred", "_observers", "_result"] _deferred: "defer.Deferred[_T]" - _observers: Union[list["defer.Deferred[_T]"], tuple[()]] - _result: Union[None, tuple[Literal[True], _T], tuple[Literal[False], Failure]] + _observers: list["defer.Deferred[_T]"] | tuple[()] + _result: None | tuple[Literal[True], _T] | tuple[Literal[False], Failure] def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) @@ -132,7 +130,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): ) return r - def errback(f: Failure) -> Optional[Failure]: + def errback(f: Failure) -> Failure | None: object.__setattr__(self, "_result", (False, f)) # once we have set _result, no more entries will be added to _observers, @@ -187,7 +185,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): def has_succeeded(self) -> bool: return self._result is not None and self._result[0] is True - def get_result(self) -> Union[_T, Failure]: + def get_result(self) -> _T | Failure: if self._result is None: raise ValueError(f"{self!r} has no result yet") return self._result[1] @@ -402,80 +400,78 @@ def gather_results( # type: ignore[misc] @overload async def gather_optional_coroutines( - *coroutines: Unpack[tuple[Optional[Coroutine[Any, Any, T1]]]], -) -> tuple[Optional[T1]]: ... + *coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None]], +) -> tuple[T1 | None]: ... @overload async def gather_optional_coroutines( *coroutines: Unpack[ tuple[ - Optional[Coroutine[Any, Any, T1]], - Optional[Coroutine[Any, Any, T2]], + Coroutine[Any, Any, T1] | None, + Coroutine[Any, Any, T2] | None, ] ], -) -> tuple[Optional[T1], Optional[T2]]: ... +) -> tuple[T1 | None, T2 | None]: ... @overload async def gather_optional_coroutines( *coroutines: Unpack[ tuple[ - Optional[Coroutine[Any, Any, T1]], - Optional[Coroutine[Any, Any, T2]], - Optional[Coroutine[Any, Any, T3]], + Coroutine[Any, Any, T1] | None, + Coroutine[Any, Any, T2] | None, + Coroutine[Any, Any, T3] | None, ] ], -) -> tuple[Optional[T1], Optional[T2], Optional[T3]]: ... +) -> tuple[T1 | None, T2 | None, T3 | None]: ... @overload async def gather_optional_coroutines( *coroutines: Unpack[ tuple[ - Optional[Coroutine[Any, Any, T1]], - Optional[Coroutine[Any, Any, T2]], - Optional[Coroutine[Any, Any, T3]], - Optional[Coroutine[Any, Any, T4]], + Coroutine[Any, Any, T1] | None, + Coroutine[Any, Any, T2] | None, + Coroutine[Any, Any, T3] | None, + Coroutine[Any, Any, T4] | None, ] ], -) -> tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ... +) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None]: ... @overload async def gather_optional_coroutines( *coroutines: Unpack[ tuple[ - Optional[Coroutine[Any, Any, T1]], - Optional[Coroutine[Any, Any, T2]], - Optional[Coroutine[Any, Any, T3]], - Optional[Coroutine[Any, Any, T4]], - Optional[Coroutine[Any, Any, T5]], + Coroutine[Any, Any, T1] | None, + Coroutine[Any, Any, T2] | None, + Coroutine[Any, Any, T3] | None, + Coroutine[Any, Any, T4] | None, + Coroutine[Any, Any, T5] | None, ] ], -) -> tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ... +) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None]: ... @overload async def gather_optional_coroutines( *coroutines: Unpack[ tuple[ - Optional[Coroutine[Any, Any, T1]], - Optional[Coroutine[Any, Any, T2]], - Optional[Coroutine[Any, Any, T3]], - Optional[Coroutine[Any, Any, T4]], - Optional[Coroutine[Any, Any, T5]], - Optional[Coroutine[Any, Any, T6]], + Coroutine[Any, Any, T1] | None, + Coroutine[Any, Any, T2] | None, + Coroutine[Any, Any, T3] | None, + Coroutine[Any, Any, T4] | None, + Coroutine[Any, Any, T5] | None, + Coroutine[Any, Any, T6] | None, ] ], -) -> tuple[ - Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5], Optional[T6] -]: ... +) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None, T6 | None]: ... async def gather_optional_coroutines( - *coroutines: Unpack[tuple[Optional[Coroutine[Any, Any, T1]], ...]], -) -> tuple[Optional[T1], ...]: + *coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None, ...]], +) -> tuple[T1 | None, ...]: """Helper function that allows waiting on multiple coroutines at once. The return value is a tuple of the return values of the coroutines in order. @@ -866,7 +862,7 @@ class DoneAwaitable(Awaitable[R]): return self.value -def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: +def maybe_awaitable(value: Awaitable[R] | R) -> Awaitable[R]: """Convert a value to an awaitable if not already an awaitable.""" if inspect.isawaitable(value): return value diff --git a/synapse/util/background_queue.py b/synapse/util/background_queue.py index 7e4c322662..93ffd9f271 100644 --- a/synapse/util/background_queue.py +++ b/synapse/util/background_queue.py @@ -21,7 +21,6 @@ from typing import ( Awaitable, Callable, Generic, - Optional, TypeVar, ) @@ -76,7 +75,7 @@ class BackgroundQueue(Generic[T]): # Indicates if a background process is running, and if so whether there # is new data in the queue. Used to signal to an existing background # process that there is new data added to the queue. - self._wakeup_event: Optional[DeferredEvent] = None + self._wakeup_event: DeferredEvent | None = None def add(self, item: T) -> None: """Add an item into the queue.""" diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index c799fca550..a65ab7f57d 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -24,7 +24,7 @@ import logging import typing from enum import Enum, auto from sys import intern -from typing import Any, Callable, Optional, Sized, TypeVar +from typing import Any, Callable, Sized, TypeVar import attr from prometheus_client import REGISTRY @@ -129,7 +129,7 @@ class CacheMetric: _cache: Sized _cache_type: str _cache_name: str - _collect_callback: Optional[Callable] + _collect_callback: Callable | None _server_name: str hits: int = 0 @@ -137,7 +137,7 @@ class CacheMetric: eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib( factory=collections.Counter ) - memory_usage: Optional[int] = None + memory_usage: int | None = None def inc_hits(self) -> None: self.hits += 1 @@ -208,9 +208,9 @@ def register_cache( cache_name: str, cache: Sized, server_name: str, - collect_callback: Optional[Callable] = None, + collect_callback: Callable | None = None, resizable: bool = True, - resize_callback: Optional[Callable] = None, + resize_callback: Callable | None = None, ) -> CacheMetric: """Register a cache object for metric collection and resizing. @@ -269,7 +269,7 @@ KNOWN_KEYS = { ) } -T = TypeVar("T", Optional[str], str) +T = TypeVar("T", str | None, str) def intern_string(string: T) -> T: diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py index 9b86017cd9..491e7e52a1 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py @@ -19,7 +19,7 @@ # # import enum -from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union +from typing import Awaitable, Callable, Generic, TypeVar from twisted.internet.defer import Deferred from twisted.python.failure import Failure @@ -74,9 +74,9 @@ class CachedCall(Generic[TV]): f: The underlying function. Only one call to this function will be alive at once (per instance of CachedCall) """ - self._callable: Optional[Callable[[], Awaitable[TV]]] = f - self._deferred: Optional[Deferred] = None - self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel + self._callable: Callable[[], Awaitable[TV]] | None = f + self._deferred: Deferred | None = None + self._result: _Sentinel | TV | Failure = _Sentinel.sentinel async def get(self) -> TV: """Kick off the call if necessary, and return the result""" @@ -93,7 +93,7 @@ class CachedCall(Generic[TV]): # result in the deferred, since `awaiting` a deferred destroys its result. # (Also, if it's a Failure, GCing the deferred would log a critical error # about unhandled Failures) - def got_result(r: Union[TV, Failure]) -> None: + def got_result(r: TV | Failure) -> None: self._result = r self._deferred.addBoth(got_result) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 380f2a78ca..a1601cd4e9 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -31,7 +31,6 @@ from typing import ( Optional, Sized, TypeVar, - Union, cast, ) @@ -107,9 +106,9 @@ class DeferredCache(Generic[KT, VT]): cache_type = TreeCache if tree else dict # _pending_deferred_cache maps from the key value to a `CacheEntry` object. - self._pending_deferred_cache: Union[ - TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]" - ] = cache_type() + self._pending_deferred_cache: ( + TreeCache | "MutableMapping[KT, CacheEntry[KT, VT]]" + ) = cache_type() def metrics_cb() -> None: cache_pending_metric.labels( @@ -136,7 +135,7 @@ class DeferredCache(Generic[KT, VT]): prune_unread_entries=prune_unread_entries, ) - self.thread: Optional[threading.Thread] = None + self.thread: threading.Thread | None = None @property def max_entries(self) -> int: @@ -155,7 +154,7 @@ class DeferredCache(Generic[KT, VT]): def get( self, key: KT, - callback: Optional[Callable[[], None]] = None, + callback: Callable[[], None] | None = None, update_metrics: bool = True, ) -> defer.Deferred: """Looks the key up in the caches. @@ -199,7 +198,7 @@ class DeferredCache(Generic[KT, VT]): def get_bulk( self, keys: Collection[KT], - callback: Optional[Callable[[], None]] = None, + callback: Callable[[], None] | None = None, ) -> tuple[dict[KT, VT], Optional["defer.Deferred[dict[KT, VT]]"], Collection[KT]]: """Bulk lookup of items in the cache. @@ -263,9 +262,7 @@ class DeferredCache(Generic[KT, VT]): return (cached, pending_deferred, missing) - def get_immediate( - self, key: KT, default: T, update_metrics: bool = True - ) -> Union[VT, T]: + def get_immediate(self, key: KT, default: T, update_metrics: bool = True) -> VT | T: """If we have a *completed* cached value, return it.""" return self.cache.get(key, default, update_metrics=update_metrics) @@ -273,7 +270,7 @@ class DeferredCache(Generic[KT, VT]): self, key: KT, value: "defer.Deferred[VT]", - callback: Optional[Callable[[], None]] = None, + callback: Callable[[], None] | None = None, ) -> defer.Deferred: """Adds a new entry to the cache (or updates an existing one). @@ -328,7 +325,7 @@ class DeferredCache(Generic[KT, VT]): def start_bulk_input( self, keys: Collection[KT], - callback: Optional[Callable[[], None]] = None, + callback: Callable[[], None] | None = None, ) -> "CacheMultipleEntries[KT, VT]": """Bulk set API for use when fetching multiple keys at once from the DB. @@ -382,7 +379,7 @@ class DeferredCache(Generic[KT, VT]): return failure def prefill( - self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None + self, key: KT, value: VT, callback: Callable[[], None] | None = None ) -> None: callbacks = (callback,) if callback else () self.cache.set(key, value, callbacks=callbacks) @@ -435,7 +432,7 @@ class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta): @abc.abstractmethod def add_invalidation_callback( - self, key: KT, callback: Optional[Callable[[], None]] + self, key: KT, callback: Callable[[], None] | None ) -> None: """Add an invalidation callback""" ... @@ -461,7 +458,7 @@ class CacheEntrySingle(CacheEntry[KT, VT]): return self._deferred.observe() def add_invalidation_callback( - self, key: KT, callback: Optional[Callable[[], None]] + self, key: KT, callback: Callable[[], None] | None ) -> None: if callback is None: return @@ -478,7 +475,7 @@ class CacheMultipleEntries(CacheEntry[KT, VT]): __slots__ = ["_deferred", "_callbacks", "_global_callbacks"] def __init__(self) -> None: - self._deferred: Optional[ObservableDeferred[dict[KT, VT]]] = None + self._deferred: ObservableDeferred[dict[KT, VT]] | None = None self._callbacks: dict[KT, set[Callable[[], None]]] = {} self._global_callbacks: set[Callable[[], None]] = set() @@ -488,7 +485,7 @@ class CacheMultipleEntries(CacheEntry[KT, VT]): return self._deferred.observe().addCallback(lambda res: res[key]) def add_invalidation_callback( - self, key: KT, callback: Optional[Callable[[], None]] + self, key: KT, callback: Callable[[], None] | None ) -> None: if callback is None: return @@ -499,7 +496,7 @@ class CacheMultipleEntries(CacheEntry[KT, VT]): return self._callbacks.get(key, set()) | self._global_callbacks def add_global_invalidation_callback( - self, callback: Optional[Callable[[], None]] + self, callback: Callable[[], None] | None ) -> None: """Add a callback for when any keys get invalidated.""" if callback is None: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 7cc83bad37..fd931cac89 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -30,11 +30,9 @@ from typing import ( Hashable, Iterable, Mapping, - Optional, Protocol, Sequence, TypeVar, - Union, cast, ) from weakref import WeakValueDictionary @@ -53,7 +51,7 @@ from synapse.util.clock import Clock logger = logging.getLogger(__name__) -CacheKey = Union[tuple, Any] +CacheKey = tuple | Any F = TypeVar("F", bound=Callable[..., Any]) @@ -76,10 +74,10 @@ class _CacheDescriptorBase: def __init__( self, orig: Callable[..., Any], - num_args: Optional[int], - uncached_args: Optional[Collection[str]] = None, + num_args: int | None, + uncached_args: Collection[str] | None = None, cache_context: bool = False, - name: Optional[str] = None, + name: str | None = None, ): self.orig = orig self.name = name or orig.__name__ @@ -216,13 +214,13 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): *, orig: Callable[..., Any], max_entries: int = 1000, - num_args: Optional[int] = None, - uncached_args: Optional[Collection[str]] = None, + num_args: int | None = None, + uncached_args: Collection[str] | None = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, - name: Optional[str] = None, + name: str | None = None, ): super().__init__( orig, @@ -243,7 +241,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): self.prune_unread_entries = prune_unread_entries def __get__( - self, obj: Optional[HasServerNameAndClock], owner: Optional[type] + self, obj: HasServerNameAndClock | None, owner: type | None ) -> Callable[..., "defer.Deferred[Any]"]: # We need access to instance-level `obj.server_name` attribute assert obj is not None, ( @@ -331,8 +329,8 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): orig: Callable[..., Awaitable[dict]], cached_method_name: str, list_name: str, - num_args: Optional[int] = None, - name: Optional[str] = None, + num_args: int | None = None, + name: str | None = None, ): """ Args: @@ -359,7 +357,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): ) def __get__( - self, obj: Optional[Any], objtype: Optional[type] = None + self, obj: Any | None, objtype: type | None = None ) -> Callable[..., "defer.Deferred[dict[Hashable, Any]]"]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache @@ -471,7 +469,7 @@ class _CacheContext: on a lower level. """ - Cache = Union[DeferredCache, LruCache] + Cache = DeferredCache | LruCache _cache_context_objects: """WeakValueDictionary[ tuple["_CacheContext.Cache", CacheKey], "_CacheContext" @@ -508,13 +506,13 @@ class _CachedFunctionDescriptor: plugin.""" max_entries: int - num_args: Optional[int] - uncached_args: Optional[Collection[str]] + num_args: int | None + uncached_args: Collection[str] | None tree: bool cache_context: bool iterable: bool prune_unread_entries: bool - name: Optional[str] + name: str | None def __call__(self, orig: F) -> CachedFunction[F]: d = DeferredCacheDescriptor( @@ -534,13 +532,13 @@ class _CachedFunctionDescriptor: def cached( *, max_entries: int = 1000, - num_args: Optional[int] = None, - uncached_args: Optional[Collection[str]] = None, + num_args: int | None = None, + uncached_args: Collection[str] | None = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, - name: Optional[str] = None, + name: str | None = None, ) -> _CachedFunctionDescriptor: return _CachedFunctionDescriptor( max_entries=max_entries, @@ -561,8 +559,8 @@ class _CachedListFunctionDescriptor: cached_method_name: str list_name: str - num_args: Optional[int] = None - name: Optional[str] = None + num_args: int | None = None + name: str | None = None def __call__(self, orig: F) -> CachedFunction[F]: d = DeferredCacheListDescriptor( @@ -579,8 +577,8 @@ def cachedList( *, cached_method_name: str, list_name: str, - num_args: Optional[int] = None, - name: Optional[str] = None, + num_args: int | None = None, + name: str | None = None, ) -> _CachedListFunctionDescriptor: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index dd6f413e79..4289e327af 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -25,9 +25,7 @@ from typing import ( Generic, Iterable, Literal, - Optional, TypeVar, - Union, ) import attr @@ -88,7 +86,7 @@ class _PerKeyValue(Generic[DV]): __slots__ = ["value"] - def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None: + def __init__(self, value: DV | Literal[_Sentinel.sentinel]) -> None: self.value = value def __len__(self) -> int: @@ -157,8 +155,8 @@ class DictionaryCache(Generic[KT, DKT, DV]): # * A key of `(KT, DKT)` has a value of `_PerKeyValue` # * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]` self.cache: LruCache[ - tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]], - Union[_PerKeyValue, dict[DKT, DV]], + tuple[KT, DKT | Literal[_FullCacheKey.KEY]], + _PerKeyValue | dict[DKT, DV], ] = LruCache( max_size=max_entries, clock=clock, @@ -170,7 +168,7 @@ class DictionaryCache(Generic[KT, DKT, DV]): self.name = name self.sequence = 0 - self.thread: Optional[threading.Thread] = None + self.thread: threading.Thread | None = None def check_thread(self) -> None: expected_thread = self.thread @@ -182,9 +180,7 @@ class DictionaryCache(Generic[KT, DKT, DV]): "Cache objects can only be accessed from the main thread" ) - def get( - self, key: KT, dict_keys: Optional[Iterable[DKT]] = None - ) -> DictionaryEntry: + def get(self, key: KT, dict_keys: Iterable[DKT] | None = None) -> DictionaryEntry: """Fetch an entry out of the cache Args: @@ -295,7 +291,7 @@ class DictionaryCache(Generic[KT, DKT, DV]): sequence: int, key: KT, value: dict[DKT, DV], - fetched_keys: Optional[Iterable[DKT]] = None, + fetched_keys: Iterable[DKT] | None = None, ) -> None: """Updates the entry in the cache. diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 29ce6c0a77..528e4bb852 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -27,9 +27,7 @@ from typing import ( Generic, Iterable, Literal, - Optional, TypeVar, - Union, overload, ) @@ -146,7 +144,7 @@ class ExpiringCache(Generic[KT, VT]): return entry.value - def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: + def pop(self, key: KT, default: T = SENTINEL) -> VT | T: """Removes and returns the value with the given key from the cache. If the key isn't in the cache then `default` will be returned if @@ -173,12 +171,12 @@ class ExpiringCache(Generic[KT, VT]): return key in self._cache @overload - def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]: ... + def get(self, key: KT, default: Literal[None] = None) -> VT | None: ... @overload - def get(self, key: KT, default: T) -> Union[VT, T]: ... + def get(self, key: KT, default: T) -> VT | T: ... - def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]: + def get(self, key: KT, default: T | None = None) -> VT | T | None: try: return self[key] except KeyError: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 04549ab65f..d304e804e9 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -33,9 +33,7 @@ from typing import ( Generic, Iterable, Literal, - Optional, TypeVar, - Union, cast, overload, ) @@ -117,14 +115,14 @@ def _expire_old_entries( hs: "HomeServer", clock: Clock, expiry_seconds: float, - autotune_config: Optional[dict], + autotune_config: dict | None, ) -> "defer.Deferred[None]": """Walks the global cache list to find cache entries that haven't been accessed in the given number of seconds, or if a given memory threshold has been breached. """ async def _internal_expire_old_entries( - clock: Clock, expiry_seconds: float, autotune_config: Optional[dict] + clock: Clock, expiry_seconds: float, autotune_config: dict | None ) -> None: if autotune_config: max_cache_memory_usage = autotune_config["max_cache_memory_usage"] @@ -281,7 +279,7 @@ class _Node(Generic[KT, VT]): prune_unread_entries: bool = True, ): self._list_node = ListNode.insert_after(self, root) - self._global_list_node: Optional[_TimedListNode] = None + self._global_list_node: _TimedListNode | None = None if USE_GLOBAL_LIST and prune_unread_entries: self._global_list_node = _TimedListNode.insert_after(self, GLOBAL_ROOT) self._global_list_node.update_last_access(clock) @@ -303,7 +301,7 @@ class _Node(Generic[KT, VT]): # footprint down. Storing `None` is free as its a singleton, while empty # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive # thing and used sets). - self.callbacks: Optional[list[Callable[[], None]]] = None + self.callbacks: list[Callable[[], None]] | None = None self.add_callbacks(callbacks) @@ -399,12 +397,12 @@ class LruCache(Generic[KT, VT]): clock: Clock, server_name: str, cache_name: str, - cache_type: type[Union[dict, TreeCache]] = dict, - size_callback: Optional[Callable[[VT], int]] = None, - metrics_collection_callback: Optional[Callable[[], None]] = None, + cache_type: type[dict | TreeCache] = dict, + size_callback: Callable[[VT], int] | None = None, + metrics_collection_callback: Callable[[], None] | None = None, apply_cache_factor_from_config: bool = True, prune_unread_entries: bool = True, - extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, + extra_index_cb: Callable[[KT, VT], KT] | None = None, ): ... @overload @@ -415,12 +413,12 @@ class LruCache(Generic[KT, VT]): clock: Clock, server_name: str, cache_name: Literal[None] = None, - cache_type: type[Union[dict, TreeCache]] = dict, - size_callback: Optional[Callable[[VT], int]] = None, - metrics_collection_callback: Optional[Callable[[], None]] = None, + cache_type: type[dict | TreeCache] = dict, + size_callback: Callable[[VT], int] | None = None, + metrics_collection_callback: Callable[[], None] | None = None, apply_cache_factor_from_config: bool = True, prune_unread_entries: bool = True, - extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, + extra_index_cb: Callable[[KT, VT], KT] | None = None, ): ... def __init__( @@ -429,13 +427,13 @@ class LruCache(Generic[KT, VT]): max_size: int, clock: Clock, server_name: str, - cache_name: Optional[str] = None, - cache_type: type[Union[dict, TreeCache]] = dict, - size_callback: Optional[Callable[[VT], int]] = None, - metrics_collection_callback: Optional[Callable[[], None]] = None, + cache_name: str | None = None, + cache_type: type[dict | TreeCache] = dict, + size_callback: Callable[[VT], int] | None = None, + metrics_collection_callback: Callable[[], None] | None = None, apply_cache_factor_from_config: bool = True, prune_unread_entries: bool = True, - extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, + extra_index_cb: Callable[[KT, VT], KT] | None = None, ): """ Args: @@ -484,7 +482,7 @@ class LruCache(Generic[KT, VT]): Note: The new key does not have to be unique. """ - cache: Union[dict[KT, _Node[KT, VT]], TreeCache] = cache_type() + cache: dict[KT, _Node[KT, VT]] | TreeCache = cache_type() self.cache = cache # Used for introspection. self.apply_cache_factor_from_config = apply_cache_factor_from_config @@ -500,10 +498,10 @@ class LruCache(Generic[KT, VT]): # register_cache might call our "set_cache_factor" callback; there's nothing to # do yet when we get resized. - self._on_resize: Optional[Callable[[], None]] = None + self._on_resize: Callable[[], None] | None = None if cache_name is not None and server_name is not None: - metrics: Optional[CacheMetric] = register_cache( + metrics: CacheMetric | None = register_cache( cache_type="lru_cache", cache_name=cache_name, cache=self, @@ -625,7 +623,7 @@ class LruCache(Generic[KT, VT]): callbacks: Collection[Callable[[], None]] = ..., update_metrics: bool = ..., update_last_access: bool = ..., - ) -> Optional[VT]: ... + ) -> VT | None: ... @overload def cache_get( @@ -634,16 +632,16 @@ class LruCache(Generic[KT, VT]): callbacks: Collection[Callable[[], None]] = ..., update_metrics: bool = ..., update_last_access: bool = ..., - ) -> Union[T, VT]: ... + ) -> T | VT: ... @synchronized def cache_get( key: KT, - default: Optional[T] = None, + default: T | None = None, callbacks: Collection[Callable[[], None]] = (), update_metrics: bool = True, update_last_access: bool = True, - ) -> Union[None, T, VT]: + ) -> None | T | VT: """Look up a key in the cache Args: @@ -677,21 +675,21 @@ class LruCache(Generic[KT, VT]): key: tuple, default: Literal[None] = None, update_metrics: bool = True, - ) -> Union[None, Iterable[tuple[KT, VT]]]: ... + ) -> None | Iterable[tuple[KT, VT]]: ... @overload def cache_get_multi( key: tuple, default: T, update_metrics: bool = True, - ) -> Union[T, Iterable[tuple[KT, VT]]]: ... + ) -> T | Iterable[tuple[KT, VT]]: ... @synchronized def cache_get_multi( key: tuple, - default: Optional[T] = None, + default: T | None = None, update_metrics: bool = True, - ) -> Union[None, T, Iterable[tuple[KT, VT]]]: + ) -> None | T | Iterable[tuple[KT, VT]]: """Returns a generator yielding all entries under the given key. Can only be used if backed by a tree cache. @@ -769,13 +767,13 @@ class LruCache(Generic[KT, VT]): return value @overload - def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]: ... + def cache_pop(key: KT, default: Literal[None] = None) -> VT | None: ... @overload - def cache_pop(key: KT, default: T) -> Union[T, VT]: ... + def cache_pop(key: KT, default: T) -> T | VT: ... @synchronized - def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: + def cache_pop(key: KT, default: T | None = None) -> None | T | VT: node = cache.get(key, None) if node: evicted_len = delete_node(node) @@ -925,22 +923,22 @@ class AsyncLruCache(Generic[KT, VT]): self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs) async def get( - self, key: KT, default: Optional[T] = None, update_metrics: bool = True - ) -> Optional[VT]: + self, key: KT, default: T | None = None, update_metrics: bool = True + ) -> VT | None: return self._lru_cache.get(key, update_metrics=update_metrics) async def get_external( self, key: KT, - default: Optional[T] = None, + default: T | None = None, update_metrics: bool = True, - ) -> Optional[VT]: + ) -> VT | None: # This method should fetch from any configured external cache, in this case noop. return None def get_local( - self, key: KT, default: Optional[T] = None, update_metrics: bool = True - ) -> Optional[VT]: + self, key: KT, default: T | None = None, update_metrics: bool = True + ) -> VT | None: return self._lru_cache.get(key, update_metrics=update_metrics) async def set(self, key: KT, value: VT) -> None: diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index e82036d7e0..b1cdc81dda 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -26,7 +26,6 @@ from typing import ( Callable, Generic, Iterable, - Optional, TypeVar, ) @@ -88,7 +87,7 @@ class ResponseCacheEntry: easier to cache Failure results. """ - opentracing_span_context: "Optional[opentracing.SpanContext]" + opentracing_span_context: "opentracing.SpanContext | None" """The opentracing span which generated/is generating the result""" @@ -150,7 +149,7 @@ class ResponseCache(Generic[KV]): """ return self._result_cache.keys() - def _get(self, key: KV) -> Optional[ResponseCacheEntry]: + def _get(self, key: KV) -> ResponseCacheEntry | None: """Look up the given key. Args: @@ -171,7 +170,7 @@ class ResponseCache(Generic[KV]): self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]", - opentracing_span_context: "Optional[opentracing.SpanContext]", + opentracing_span_context: "opentracing.SpanContext | None", ) -> ResponseCacheEntry: """Set the entry for the given key to the given deferred. @@ -289,7 +288,7 @@ class ResponseCache(Generic[KV]): if cache_context: kwargs["cache_context"] = context - span_context: Optional[opentracing.SpanContext] = None + span_context: opentracing.SpanContext | None = None async def cb() -> RV: # NB it is important that we do not `await` before setting span_context! diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 552570fbb9..7c6c9bc572 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -21,7 +21,7 @@ import logging import math -from typing import Collection, Mapping, Optional, Union +from typing import Collection, Mapping import attr from sortedcontainers import SortedDict @@ -45,7 +45,7 @@ class AllEntitiesChangedResult: that callers do the correct checks. """ - _entities: Optional[list[EntityType]] + _entities: list[EntityType] | None @property def hit(self) -> bool: @@ -78,7 +78,7 @@ class StreamChangeCache: server_name: str, current_stream_pos: int, max_size: int = 10000, - prefilled_cache: Optional[Mapping[EntityType, int]] = None, + prefilled_cache: Mapping[EntityType, int] | None = None, ) -> None: """ Args: @@ -182,7 +182,7 @@ class StreamChangeCache: def get_entities_changed( self, entities: Collection[EntityType], stream_pos: int, _perf_factor: int = 1 - ) -> Union[set[EntityType], frozenset[EntityType]]: + ) -> set[EntityType] | frozenset[EntityType]: """ Returns the subset of the given entities that have had changes after the given position. @@ -352,7 +352,7 @@ class StreamChangeCache: for entity in r: self._entity_to_key.pop(entity, None) - def get_max_pos_of_last_change(self, entity: EntityType) -> Optional[int]: + def get_max_pos_of_last_change(self, entity: EntityType) -> int | None: """Returns an upper bound of the stream id of the last change to an entity. diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 2be9463d6a..25b87832d8 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -21,7 +21,7 @@ import logging import time -from typing import Any, Callable, Generic, TypeVar, Union +from typing import Any, Callable, Generic, TypeVar import attr from sortedcontainers import SortedList @@ -91,7 +91,7 @@ class TTLCache(Generic[KT, VT]): self._data[key] = entry self._expiry_list.add(entry) - def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: + def get(self, key: KT, default: T = SENTINEL) -> VT | T: """Get a value from the cache Args: @@ -134,7 +134,7 @@ class TTLCache(Generic[KT, VT]): self._metrics.inc_hits() return e.value, e.expiry_time, e.ttl - def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: + def pop(self, key: KT, default: T = SENTINEL) -> VT | T: """Remove a value from the cache If key is in the cache, remove it and return its value, else return default. diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 715240c8ce..7e92b55592 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -28,7 +28,7 @@ require. But this is probably just symptomatic of Python's package management. import logging from importlib import metadata -from typing import Any, Iterable, NamedTuple, Optional, Sequence, cast +from typing import Any, Iterable, NamedTuple, Sequence, cast from packaging.markers import Marker, Value, Variable, default_environment from packaging.requirements import Requirement @@ -153,7 +153,7 @@ def _values_from_marker_value(value: Value) -> set[str]: return {str(raw)} -def _extras_from_marker(marker: Optional[Marker]) -> set[str]: +def _extras_from_marker(marker: Marker | None) -> set[str]: """Return every `extra` referenced in the supplied marker tree.""" extras: set[str] = set() @@ -214,7 +214,7 @@ def _marker_applies_for_any_extra(requirement: Requirement, extras: set[str]) -> ) -def _not_installed(requirement: Requirement, extra: Optional[str] = None) -> str: +def _not_installed(requirement: Requirement, extra: str | None = None) -> str: if extra: return ( f"Synapse {VERSION} needs {requirement.name} for {extra}, " @@ -225,7 +225,7 @@ def _not_installed(requirement: Requirement, extra: Optional[str] = None) -> str def _incorrect_version( - requirement: Requirement, got: str, extra: Optional[str] = None + requirement: Requirement, got: str, extra: str | None = None ) -> str: if extra: return ( @@ -238,7 +238,7 @@ def _incorrect_version( ) -def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str: +def _no_reported_version(requirement: Requirement, extra: str | None = None) -> str: if extra: return ( f"Synapse {VERSION} needs {requirement} for {extra}, " @@ -251,7 +251,7 @@ def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) ) -def check_requirements(extra: Optional[str] = None) -> None: +def check_requirements(extra: str | None = None) -> None: """Check Synapse's dependencies are present and correctly versioned. If provided, `extra` must be the name of an packaging extra (e.g. "saml2" in diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index 411b47f939..63e0571a78 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -27,7 +27,7 @@ import os import signal import sys from types import FrameType, TracebackType -from typing import NoReturn, Optional +from typing import NoReturn from synapse.logging.context import ( LoggingContext, @@ -121,7 +121,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - def excepthook( type_: type[BaseException], value: BaseException, - traceback: Optional[TracebackType], + traceback: TracebackType | None, ) -> None: logger.critical("Unhanded exception", exc_info=(type_, value, traceback)) @@ -144,7 +144,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - sys.exit(1) # write a log line on SIGTERM. - def sigterm(signum: int, frame: Optional[FrameType]) -> NoReturn: + def sigterm(signum: int, frame: FrameType | None) -> NoReturn: logger.warning("Caught signal %s. Stopping daemon.", signum) sys.exit(0) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index e8df5399cd..23ef67c752 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -25,9 +25,7 @@ from typing import ( Awaitable, Callable, Generic, - Optional, TypeVar, - Union, ) from typing_extensions import ParamSpec @@ -137,7 +135,7 @@ class Signal(Generic[P]): Returns a Deferred that will complete when all the observers have completed.""" - async def do(observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]: + async def do(observer: Callable[P, R | Awaitable[R]]) -> R | None: try: return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: diff --git a/synapse/util/events.py b/synapse/util/events.py index 4a1aa28ce4..19eca1c1ae 100644 --- a/synapse/util/events.py +++ b/synapse/util/events.py @@ -13,7 +13,7 @@ # # -from typing import Any, Optional +from typing import Any from pydantic import Field, StrictStr, ValidationError, field_validator @@ -41,7 +41,7 @@ class MTextRepresentation(ParseModel): """ body: StrictStr - mimetype: Optional[StrictStr] = None + mimetype: StrictStr | None = None class MTopic(ParseModel): @@ -53,7 +53,7 @@ class MTopic(ParseModel): See `TopicContentBlock` in the Matrix specification. """ - m_text: Optional[list[MTextRepresentation]] = Field(None, alias="m.text") + m_text: list[MTextRepresentation] | None = Field(None, alias="m.text") """ An ordered array of textual representations in different mimetypes. """ @@ -65,7 +65,7 @@ class MTopic(ParseModel): @classmethod def ignore_invalid_representations( cls, m_text: Any - ) -> Optional[list[MTextRepresentation]]: + ) -> list[MTextRepresentation] | None: if not isinstance(m_text, (list, tuple)): raise ValueError("m.text must be a list or a tuple") representations = [] @@ -87,7 +87,7 @@ class TopicContent(ParseModel): The topic in plain text. """ - m_topic: Optional[MTopic] = Field(None, alias="m.topic") + m_topic: MTopic | None = Field(None, alias="m.topic") """ Textual representation of the room topic in different mimetypes. """ @@ -96,14 +96,14 @@ class TopicContent(ParseModel): # `topic` field. @field_validator("m_topic", mode="before") @classmethod - def ignore_invalid_m_topic(cls, m_topic: Any) -> Optional[MTopic]: + def ignore_invalid_m_topic(cls, m_topic: Any) -> MTopic | None: try: return MTopic.model_validate(m_topic) except ValidationError: return None -def get_plain_text_topic_from_event_content(content: JsonDict) -> Optional[str]: +def get_plain_text_topic_from_event_content(content: JsonDict) -> str | None: """ Given the `content` of an `m.room.topic` event, returns the plain-text topic representation. Prefers pulling plain-text from the newer `m.topic` field if diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 9fa8d40234..8d64684084 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -19,7 +19,7 @@ # import queue -from typing import Any, BinaryIO, Optional, Union, cast +from typing import Any, BinaryIO, cast from twisted.internet import threads from twisted.internet.defer import Deferred @@ -50,7 +50,7 @@ class BackgroundFileConsumer: self._reactor: ISynapseReactor = reactor # Producer we're registered with - self._producer: Optional[Union[IPushProducer, IPullProducer]] = None + self._producer: IPushProducer | IPullProducer | None = None # True if PushProducer, false if PullProducer self.streaming = False @@ -61,18 +61,18 @@ class BackgroundFileConsumer: # Queue of slices of bytes to be written. When producer calls # unregister a final None is sent. - self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() + self._bytes_queue: queue.Queue[bytes | None] = queue.Queue() # Deferred that is resolved when finished writing # # This is really Deferred[None], but mypy doesn't seem to like that. - self._finished_deferred: Optional[Deferred[Any]] = None + self._finished_deferred: Deferred[Any] | None = None # If the _writer thread throws an exception it gets stored here. - self._write_exception: Optional[Exception] = None + self._write_exception: Exception | None = None def registerProducer( - self, producer: Union[IPushProducer, IPullProducer], streaming: bool + self, producer: IPushProducer | IPullProducer, streaming: bool ) -> None: """Part of IConsumer interface diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py index e07003f1af..f40de8dcc2 100644 --- a/synapse/util/gai_resolver.py +++ b/synapse/util/gai_resolver.py @@ -18,9 +18,7 @@ from typing import ( TYPE_CHECKING, Callable, NoReturn, - Optional, Sequence, - Union, ) from zope.interface import implementer @@ -94,7 +92,7 @@ _GETADDRINFO_RESULT = list[ SocketKind, int, str, - Union[tuple[str, int], tuple[str, int, int, int], tuple[int, bytes]], + tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes], ] ] @@ -109,7 +107,7 @@ class GAIResolver: def __init__( self, reactor: IReactorThreads, - getThreadPool: Optional[Callable[[], "ThreadPool"]] = None, + getThreadPool: Callable[[], "ThreadPool"] | None = None, getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo, ): """ @@ -138,7 +136,7 @@ class GAIResolver: resolutionReceiver: IResolutionReceiver, hostName: str, portNumber: int = 0, - addressTypes: Optional[Sequence[type[IAddress]]] = None, + addressTypes: Sequence[type[IAddress]] | None = None, transportSemantics: str = "TCP", ) -> IHostResolution: """ diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py index 052863fdd6..c7a164d02e 100644 --- a/synapse/util/linked_list.py +++ b/synapse/util/linked_list.py @@ -22,7 +22,7 @@ """A circular doubly linked list implementation.""" import threading -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar P = TypeVar("P") LN = TypeVar("LN", bound="ListNode") @@ -47,10 +47,10 @@ class ListNode(Generic[P]): "next_node", ] - def __init__(self, cache_entry: Optional[P] = None) -> None: + def __init__(self, cache_entry: P | None = None) -> None: self.cache_entry = cache_entry - self.prev_node: Optional[ListNode[P]] = None - self.next_node: Optional[ListNode[P]] = None + self.prev_node: ListNode[P] | None = None + self.next_node: ListNode[P] | None = None @classmethod def create_root_node(cls: type["ListNode[P]"]) -> "ListNode[P]": @@ -149,7 +149,7 @@ class ListNode(Generic[P]): prev_node.next_node = self next_node.prev_node = self - def get_cache_entry(self) -> Optional[P]: + def get_cache_entry(self) -> P | None: """Get the cache entry, returns None if this is the root node (i.e. cache_entry is None) or if the entry has been dropped. """ diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index d683a57ab1..178b6fa377 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -22,7 +22,7 @@ """Utilities for manipulating macaroons""" -from typing import Callable, Literal, Optional +from typing import Callable, Literal import attr import pymacaroons @@ -52,7 +52,7 @@ def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: caveat in the macaroon, or if the caveat was not found in the macaroon. """ prefix = key + " = " - result: Optional[str] = None + result: str | None = None for caveat in macaroon.caveats: if not caveat.caveat_id.startswith(prefix): continue diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index dbf444e015..859e9a9072 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -21,7 +21,7 @@ import inspect import sys import traceback -from typing import Any, Optional +from typing import Any from twisted.conch import manhole_ssh from twisted.conch.insults import insults @@ -130,7 +130,7 @@ class SynapseManhole(ColoredManhole): class SynapseManholeInterpreter(ManholeInterpreter): - def showsyntaxerror(self, filename: Optional[str] = None) -> None: + def showsyntaxerror(self, filename: str | None = None) -> None: """Display the syntax error that just occurred. Overrides the base implementation, ignoring sys.excepthook. We always want diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 6d1adf1131..3daba79124 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -26,7 +26,6 @@ from typing import ( Awaitable, Callable, Generator, - Optional, Protocol, TypeVar, ) @@ -136,7 +135,7 @@ class HasClockAndServerName(Protocol): def measure_func( - name: Optional[str] = None, + name: str | None = None, ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: """Decorate an async method with a `Measure` context manager. @@ -220,7 +219,7 @@ class Measure: server_name=self.server_name, parent_context=parent_context, ) - self.start: Optional[float] = None + self.start: float | None = None def __enter__(self) -> "Measure": if self.start is not None: @@ -236,9 +235,9 @@ class Measure: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: if self.start is None: raise RuntimeError("Measure() block exited without being entered") diff --git a/synapse/util/pydantic_models.py b/synapse/util/pydantic_models.py index e1e2d8b99f..f1d35a35ec 100644 --- a/synapse/util/pydantic_models.py +++ b/synapse/util/pydantic_models.py @@ -13,7 +13,7 @@ # # -from typing import Annotated, Union +from typing import Annotated from pydantic import AfterValidator, BaseModel, ConfigDict, StrictStr, StringConstraints @@ -53,4 +53,4 @@ EventIdV1And2 = Annotated[StrictStr, AfterValidator(validate_event_id_v1_and_2)] EventIdV3Plus = Annotated[ StrictStr, StringConstraints(pattern=r"^\$([a-zA-Z0-9-_]{43}|[a-zA-Z0-9+/]{43})$") ] -AnyEventId = Union[EventIdV1And2, EventIdV3Plus] +AnyEventId = EventIdV1And2 | EventIdV3Plus diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 37d2e4505d..024706d9cf 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -31,7 +31,6 @@ from typing import ( Iterator, Mapping, MutableSet, - Optional, ) from weakref import WeakSet @@ -164,7 +163,7 @@ class FederationRateLimiter: our_server_name: str, clock: Clock, config: FederationRatelimitSettings, - metrics_name: Optional[str] = None, + metrics_name: str | None = None, ): """ Args: @@ -217,7 +216,7 @@ class _PerHostRatelimiter: our_server_name: str, clock: Clock, config: FederationRatelimitSettings, - metrics_name: Optional[str] = None, + metrics_name: str | None = None, ): """ Args: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index ce747c3f19..8a5aab50f1 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -168,7 +168,7 @@ class RetryDestinationLimiter: hs: "HomeServer", clock: Clock, store: DataStore, - failure_ts: Optional[int], + failure_ts: int | None, retry_interval: int, backoff_on_404: bool = False, backoff_on_failure: bool = True, @@ -230,9 +230,9 @@ class RetryDestinationLimiter: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: success = exc_type is None valid_err_code = False diff --git a/synapse/util/rust.py b/synapse/util/rust.py index 37f43459f1..63b53b917f 100644 --- a/synapse/util/rust.py +++ b/synapse/util/rust.py @@ -24,7 +24,6 @@ import os import urllib.parse from hashlib import blake2b from importlib.metadata import Distribution, PackageNotFoundError -from typing import Optional import synapse from synapse.synapse_rust import get_rust_file_digest @@ -80,7 +79,7 @@ def _hash_rust_files_in_directory(directory: str) -> str: return hasher.hexdigest() -def get_synapse_source_directory() -> Optional[str]: +def get_synapse_source_directory() -> str | None: """Try and find the source directory of synapse for editable installs (like those used in development). diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 0dadafbc78..cc26c5181b 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -24,7 +24,7 @@ import random import re import secrets import string -from typing import Any, Iterable, Optional +from typing import Any, Iterable from netaddr import valid_ipv6 @@ -109,7 +109,7 @@ def assert_valid_client_secret(client_secret: str) -> None: ) -def parse_server_name(server_name: str) -> tuple[str, Optional[int]]: +def parse_server_name(server_name: str) -> tuple[str, int | None]: """Split a server name into host/port parts. Args: @@ -140,7 +140,7 @@ def parse_server_name(server_name: str) -> tuple[str, Optional[int]]: VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z") -def parse_and_validate_server_name(server_name: str) -> tuple[str, Optional[int]]: +def parse_and_validate_server_name(server_name: str) -> tuple[str, int | None]: """Split a server name into host/port parts and do some basic validation. Args: @@ -207,7 +207,7 @@ def valid_id_server_location(id_server: str) -> bool: return "#" not in path and "?" not in path -def parse_and_validate_mxc_uri(mxc: str) -> tuple[str, Optional[int], str]: +def parse_and_validate_mxc_uri(mxc: str) -> tuple[str, int | None, str]: """Parse the given string as an MXC URI Checks that the "server name" part is a valid server name @@ -285,7 +285,7 @@ def base62_encode(num: int, minwidth: int = 1) -> str: return pad + res -def non_null_str_or_none(val: Any) -> Optional[str]: +def non_null_str_or_none(val: Any) -> str | None: """Check that the arg is a string containing no null (U+0000) codepoints. If so, returns the given string unmodified; otherwise, returns None. diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 22b3bf8c15..3b4423a1ff 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -20,7 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Optional +from typing import TYPE_CHECKING, Awaitable, Callable from twisted.python.failure import Failure @@ -116,7 +116,7 @@ class TaskScheduler: str, Callable[ [ScheduledTask], - Awaitable[tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], + Awaitable[tuple[TaskStatus, JsonMapping | None, str | None]], ], ] = {} self._run_background_tasks = hs.config.worker.run_background_tasks @@ -143,7 +143,7 @@ class TaskScheduler: self, function: Callable[ [ScheduledTask], - Awaitable[tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], + Awaitable[tuple[TaskStatus, JsonMapping | None, str | None]], ], action_name: str, ) -> None: @@ -167,9 +167,9 @@ class TaskScheduler: self, action: str, *, - resource_id: Optional[str] = None, - timestamp: Optional[int] = None, - params: Optional[JsonMapping] = None, + resource_id: str | None = None, + timestamp: int | None = None, + params: JsonMapping | None = None, ) -> str: """Schedule a new potentially resumable task. A function matching the specified `action` should've been registered with `register_action` before the task is run. @@ -220,10 +220,10 @@ class TaskScheduler: self, id: str, *, - timestamp: Optional[int] = None, - status: Optional[TaskStatus] = None, - result: Optional[JsonMapping] = None, - error: Optional[str] = None, + timestamp: int | None = None, + status: TaskStatus | None = None, + result: JsonMapping | None = None, + error: str | None = None, ) -> bool: """Update some task-associated values. This is exposed publicly so it can be used inside task functions, mainly to update the result or resume @@ -263,7 +263,7 @@ class TaskScheduler: error=error, ) - async def get_task(self, id: str) -> Optional[ScheduledTask]: + async def get_task(self, id: str) -> ScheduledTask | None: """Get a specific task description by id. Args: @@ -278,11 +278,11 @@ class TaskScheduler: async def get_tasks( self, *, - actions: Optional[list[str]] = None, - resource_id: Optional[str] = None, - statuses: Optional[list[TaskStatus]] = None, - max_timestamp: Optional[int] = None, - limit: Optional[int] = None, + actions: list[str] | None = None, + resource_id: str | None = None, + statuses: list[TaskStatus] | None = None, + max_timestamp: int | None = None, + limit: int | None = None, ) -> list[ScheduledTask]: """Get a list of tasks. Returns all the tasks if no args are provided. diff --git a/synapse/util/templates.py b/synapse/util/templates.py index fc5dbc069c..d399b167c1 100644 --- a/synapse/util/templates.py +++ b/synapse/util/templates.py @@ -23,7 +23,7 @@ import time import urllib.parse -from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Callable, Sequence import jinja2 @@ -34,7 +34,7 @@ if TYPE_CHECKING: def build_jinja_env( template_search_directories: Sequence[str], config: "HomeServerConfig", - autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None, + autoescape: bool | Callable[[str | None], bool] | None = None, ) -> jinja2.Environment: """Set up a Jinja2 environment to load templates from the given search path @@ -82,7 +82,7 @@ def build_jinja_env( def _create_mxc_to_http_filter( - public_baseurl: Optional[str], + public_baseurl: str | None, ) -> Callable[[str, int, int, str], str]: """Create and return a jinja2 filter that converts MXC urls to HTTP diff --git a/synapse/visibility.py b/synapse/visibility.py index 41b6198af0..16b39e6200 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -24,7 +24,6 @@ from enum import Enum, auto from typing import ( Collection, Final, - Optional, Sequence, ) @@ -162,7 +161,7 @@ async def filter_events_for_client( room_id ] = await storage.main.get_retention_policy_for_room(room_id) - def allowed(event: EventBase) -> Optional[EventBase]: + def allowed(event: EventBase) -> EventBase | None: state_after_event = event_id_to_state.get(event.event_id) filtered = _check_client_allowed_to_see_event( user_id=user_id, @@ -185,7 +184,7 @@ async def filter_events_for_client( # we won't have such a state. The only outliers that are returned here are the # user's own membership event, so we can just inspect that. - user_membership_event: Optional[EventBase] + user_membership_event: EventBase | None if event.type == EventTypes.Member and event.state_key == user_id: user_membership_event = event elif state_after_event is not None: @@ -349,9 +348,9 @@ def _check_client_allowed_to_see_event( always_include_ids: frozenset[str], sender_ignored: bool, retention_policy: RetentionPolicy, - state: Optional[StateMap[EventBase]], + state: StateMap[EventBase] | None, sender_erased: bool, -) -> Optional[EventBase]: +) -> EventBase | None: """Check with the given user is allowed to see the given event See `filter_events_for_client` for details about args diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py index db77484f4c..d89f487d3d 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py @@ -22,7 +22,6 @@ import logging import logging.config import warnings from io import StringIO -from typing import Optional from unittest.mock import Mock from pyperf import perf_counter @@ -58,7 +57,7 @@ class LineCounter(LineOnlyReceiver): class Factory(ServerFactory): protocol = LineCounter wait_for: int - on_done: Optional[Deferred] + on_done: Deferred | None async def main(reactor: ISynapseReactor, loops: int) -> float: diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 34369a8746..0ef537841d 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,5 +1,3 @@ -from typing import Optional - from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from synapse.appservice import ApplicationService from synapse.config.ratelimiting import RatelimitSettings @@ -489,7 +487,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): # and limiter name. async def get_ratelimit_override_for_user( user_id: str, limiter_name: str - ) -> Optional[RatelimitOverride]: + ) -> RatelimitOverride | None: if user_id == test_user_id: return RatelimitOverride( per_second=0.1, diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 1943292a8f..bf55f261bb 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Any, Mapping, Optional, Sequence, Union +from typing import Any, Mapping, Sequence from unittest.mock import Mock from twisted.internet.testing import MemoryReactor @@ -80,7 +80,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): async def get_json( url: str, args: Mapping[Any, Any], - headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], + headers: Mapping[str | bytes, Sequence[str | bytes]], ) -> list[JsonDict]: # Ensure the access token is passed as a header. if not headers or not headers.get(b"Authorization"): @@ -154,9 +154,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): async def get_json( url: str, args: Mapping[Any, Any], - headers: Optional[ - Mapping[Union[str, bytes], Sequence[Union[str, bytes]]] - ] = None, + headers: Mapping[str | bytes, Sequence[str | bytes]] | None = None, ) -> list[JsonDict]: # Ensure the access token is passed as a both a query param and in the headers. if not args.get(b"access_token"): @@ -216,7 +214,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): async def post_json_get_json( uri: str, post_json: Any, - headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], + headers: Mapping[str | bytes, Sequence[str | bytes]], ) -> JsonDict: # Ensure the access token is passed as both a header and query arg. if not headers.get(b"Authorization"): diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index f17957c206..3caf006386 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional, Sequence +from typing import Sequence from unittest.mock import AsyncMock, Mock from typing_extensions import TypeAlias @@ -190,9 +190,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.HomeserverTestCase): # return one txn to send, then no more old txns txns = [txn, None] - def take_txn( - *args: object, **kwargs: object - ) -> "defer.Deferred[Optional[Mock]]": + def take_txn(*args: object, **kwargs: object) -> "defer.Deferred[Mock | None]": return defer.succeed(txns.pop(0)) self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) @@ -216,9 +214,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.HomeserverTestCase): txns = [txn, None] pop_txn = False - def take_txn( - *args: object, **kwargs: object - ) -> "defer.Deferred[Optional[Mock]]": + def take_txn(*args: object, **kwargs: object) -> "defer.Deferred[Mock | None]": if pop_txn: return defer.succeed(txns.pop(0)) else: @@ -254,9 +250,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.HomeserverTestCase): txns = [txn, None] pop_txn = False - def take_txn( - *args: object, **kwargs: object - ) -> "defer.Deferred[Optional[Mock]]": + def take_txn(*args: object, **kwargs: object) -> "defer.Deferred[Mock | None]": if pop_txn: return defer.succeed(txns.pop(0)) else: @@ -291,11 +285,11 @@ defer.Deferred[ tuple[ ApplicationService, Sequence[EventBase], - Optional[list[JsonDict]], - Optional[list[JsonDict]], - Optional[TransactionOneTimeKeysCount], - Optional[TransactionUnusedFallbackKeys], - Optional[DeviceListUpdates], + list[JsonDict] | None, + list[JsonDict] | None, + TransactionOneTimeKeysCount | None, + TransactionUnusedFallbackKeys | None, + DeviceListUpdates | None, ] ] """ diff --git a/tests/config/test_workers.py b/tests/config/test_workers.py index 3a21975b89..55439a502c 100644 --- a/tests/config/test_workers.py +++ b/tests/config/test_workers.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Any, Mapping, Optional +from typing import Any, Mapping from unittest.mock import Mock from immutabledict import immutabledict @@ -35,7 +35,7 @@ class WorkerDutyConfigTestCase(TestCase): def _make_worker_config( self, worker_app: str, - worker_name: Optional[str], + worker_name: str | None, extras: Mapping[str, Any] = _EMPTY_IMMUTABLEDICT, ) -> WorkerConfig: root_config = Mock() diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 2eaf77e9dc..d3e8da97f8 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,7 +19,7 @@ # # import time -from typing import Any, Optional, cast +from typing import Any, cast from unittest.mock import Mock import attr @@ -87,7 +87,7 @@ class FakeRequest: @logcontext_clean class KeyringTestCase(unittest.HomeserverTestCase): def check_context( - self, val: ContextRequest, expected: Optional[ContextRequest] + self, val: ContextRequest, expected: ContextRequest | None ) -> ContextRequest: self.assertEqual(getattr(current_context(), "request", None), expected) return val diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py index 623ec67ed6..72ade45758 100644 --- a/tests/events/test_auto_accept_invites.py +++ b/tests/events/test_auto_accept_invites.py @@ -20,7 +20,7 @@ # import asyncio from http import HTTPStatus -from typing import Any, Optional, TypeVar, cast +from typing import Any, TypeVar, cast from unittest.mock import Mock import attr @@ -525,7 +525,7 @@ def generate_request_key() -> SyncRequestKey: def sync_join( testcase: HomeserverTestCase, user_id: str, - since_token: Optional[StreamToken] = None, + since_token: StreamToken | None = None, ) -> tuple[list[JoinedSyncResult], StreamToken]: """Perform a sync request for the given user and return the user join updates they've received, as well as the next_batch token. @@ -766,7 +766,7 @@ class MockEvent: type: str content: dict[str, Any] room_id: str = "!someroom" - state_key: Optional[str] = None + state_key: str | None = None def is_state(self) -> bool: """Checks if the event is a state event by checking if it has a state key.""" @@ -793,7 +793,7 @@ async def make_awaitable(value: T) -> T: def create_module( - config_override: Optional[dict[str, Any]] = None, worker_name: Optional[str] = None + config_override: dict[str, Any] | None = None, worker_name: str | None = None ) -> InviteAutoAccepter: # Create a mock based on the ModuleApi spec, but override some mocked functions # because some capabilities are needed for running the tests. diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index aa8d7454c0..4132050647 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Iterable, Optional, Union +from typing import Iterable from unittest.mock import AsyncMock, Mock import attr @@ -63,7 +63,7 @@ class LegacyPresenceRouterTestModule: } return users_to_state - async def get_interested_users(self, user_id: str) -> Union[set[str], str]: + async def get_interested_users(self, user_id: str) -> set[str] | str: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -113,7 +113,7 @@ class PresenceRouterTestModule: } return users_to_state - async def get_interested_users(self, user_id: str) -> Union[set[str], str]: + async def get_interested_users(self, user_id: str) -> set[str] | str: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -482,7 +482,7 @@ def send_presence_update( user_id: str, access_token: str, presence_state: str, - status_message: Optional[str] = None, + status_message: str | None = None, ) -> JsonDict: # Build the presence body body = {"presence": presence_state} @@ -510,7 +510,7 @@ def generate_request_key() -> SyncRequestKey: def sync_presence( testcase: HomeserverTestCase, user_id: str, - since_token: Optional[StreamToken] = None, + since_token: StreamToken | None = None, ) -> tuple[list[UserPresenceState], StreamToken]: """Perform a sync request for the given user and return the user presence updates they've received, as well as the next_batch token. diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 9d41067844..9ea015e138 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -20,7 +20,7 @@ # import unittest as stdlib_unittest -from typing import Any, Mapping, Optional +from typing import Any, Mapping import attr from parameterized import parameterized @@ -648,7 +648,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): def serialize( self, ev: EventBase, - fields: Optional[list[str]], + fields: list[str] | None, include_admin_metadata: bool = False, ) -> JsonDict: return serialize_event( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 34b552b9ed..fd1ef043bb 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -1,4 +1,4 @@ -from typing import Callable, Collection, Optional +from typing import Callable, Collection from unittest import mock from unittest.mock import AsyncMock, Mock @@ -72,7 +72,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): return config async def record_transaction( - self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] + self, txn: Transaction, json_cb: Callable[[], JsonDict] | None ) -> JsonDict: if json_cb is None: # The tests seem to expect that this method raises in this situation. diff --git a/tests/federation/test_federation_out_of_band_membership.py b/tests/federation/test_federation_out_of_band_membership.py index 905f9e6580..a1ab72b7a1 100644 --- a/tests/federation/test_federation_out_of_band_membership.py +++ b/tests/federation/test_federation_out_of_band_membership.py @@ -23,7 +23,7 @@ import logging import time import urllib.parse from http import HTTPStatus -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, TypeVar from unittest.mock import Mock import attr @@ -146,7 +146,7 @@ class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase): self.storage_controllers = hs.get_storage_controllers() def do_sync( - self, sync_body: JsonDict, *, since: Optional[str] = None, tok: str + self, sync_body: JsonDict, *, since: str | None = None, tok: str ) -> tuple[JsonDict, str]: """Do a sliding sync request with given body. @@ -326,13 +326,13 @@ class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase): async def get_json( destination: str, path: str, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser[T]] = None, - ) -> Union[JsonDict, T]: + parser: ByteParser[T] | None = None, + ) -> JsonDict | T: if ( path == f"/_matrix/federation/v1/make_join/{urllib.parse.quote_plus(remote_room_id)}/{urllib.parse.quote_plus(local_user1_id)}" @@ -355,17 +355,17 @@ class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase): async def put_json( destination: str, path: str, - args: Optional[QueryParams] = None, - data: Optional[JsonDict] = None, - json_data_callback: Optional[Callable[[], JsonDict]] = None, + args: QueryParams | None = None, + data: JsonDict | None = None, + json_data_callback: Callable[[], JsonDict] | None = None, long_retries: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser[T]] = None, + parser: ByteParser[T] | None = None, backoff_on_all_error_codes: bool = False, - ) -> Union[JsonDict, T, SendJoinResponse]: + ) -> JsonDict | T | SendJoinResponse: if ( path.startswith( f"/_matrix/federation/v2/send_join/{urllib.parse.quote_plus(remote_room_id)}/" @@ -508,17 +508,17 @@ class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase): async def put_json( destination: str, path: str, - args: Optional[QueryParams] = None, - data: Optional[JsonDict] = None, - json_data_callback: Optional[Callable[[], JsonDict]] = None, + args: QueryParams | None = None, + data: JsonDict | None = None, + json_data_callback: Callable[[], JsonDict] | None = None, long_retries: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, - parser: Optional[ByteParser[T]] = None, + parser: ByteParser[T] | None = None, backoff_on_all_error_codes: bool = False, - ) -> Union[JsonDict, T]: + ) -> JsonDict | T: if path.startswith("/_matrix/federation/v1/send/") and data is not None: for pdu in data.get("pdus", []): event = event_from_pdu_json(pdu, room_version) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 20b67e3a73..ced98a8b00 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -17,7 +17,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Callable, Optional +from typing import Callable from unittest.mock import AsyncMock, Mock from signedjson import key, sign @@ -510,7 +510,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) async def record_transaction( - self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None + self, txn: Transaction, json_cb: Callable[[], JsonDict] | None = None ) -> JsonDict: assert json_cb is not None data = json_cb() @@ -592,7 +592,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # expect two edus self.assertEqual(len(self.edus), 2) - stream_id: Optional[int] = None + stream_id: int | None = None stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) @@ -754,7 +754,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # for each device, there should be a single update self.assertEqual(len(self.edus), 3) - stream_id: Optional[int] = None + stream_id: int | None = None for edu in self.edus: self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE) c = edu["content"] @@ -876,7 +876,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): edu: JsonDict, user_id: str, device_id: str, - prev_stream_id: Optional[int], + prev_stream_id: int | None, ) -> int: """Check that the given EDU is an update for the given device Returns the stream_id. diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index b1371d0ac7..0d74791290 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -20,7 +20,6 @@ # import logging from http import HTTPStatus -from typing import Optional, Union from unittest.mock import Mock from parameterized import parameterized @@ -192,12 +191,12 @@ class MessageAcceptTests(unittest.FederatingHomeserverTestCase): async def post_json( destination: str, path: str, - data: Optional[JsonDict] = None, + data: JsonDict | None = None, long_retries: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ignore_backoff: bool = False, - args: Optional[QueryParams] = None, - ) -> Union[JsonDict, list]: + args: QueryParams | None = None, + ) -> JsonDict | list: # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index f538b67e41..9a6bbabd35 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -20,7 +20,6 @@ # import json -from typing import Optional from unittest.mock import Mock import ijson.common @@ -98,7 +97,7 @@ class SendJoinParserTestCase(TestCase): def test_servers_in_room(self) -> None: """Check that the servers_in_room field is correctly parsed""" - def parse(response: JsonDict) -> Optional[list[str]]: + def parse(response: JsonDict) -> list[str] | None: parser = SendJoinParser(RoomVersions.V1, False) serialised_response = json.dumps(response).encode() diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 9e92b06d91..ec705676cc 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -19,7 +19,7 @@ # # from collections import OrderedDict -from typing import Any, Optional +from typing import Any from twisted.internet.testing import MemoryReactor @@ -232,7 +232,7 @@ class FederationKnockingTestCase( # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. async def _check_event_auth( - origin: Optional[str], event: EventBase, context: EventContext + origin: str | None, event: EventBase, context: EventContext ) -> None: pass diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 7d6bd35a9a..6336edb108 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -25,7 +25,6 @@ from typing import ( Awaitable, Callable, Iterable, - Optional, TypeVar, ) from unittest.mock import AsyncMock, Mock @@ -81,10 +80,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_run_as_background_process( desc: "LiteralString", - func: Callable[..., Awaitable[Optional[R]]], + func: Callable[..., Awaitable[R | None]], *args: Any, **kwargs: Any, - ) -> "defer.Deferred[Optional[R]]": + ) -> "defer.Deferred[R | None]": # Ignore linter error as this is used only for testing purposes (i.e. outside of Synapse). return run_as_background_process(desc, "test_server", func, *args, **kwargs) # type: ignore[untracked-background-process] @@ -293,7 +292,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): async def get_3pe_protocol( service: ApplicationService, protocol: str - ) -> Optional[JsonDict]: + ) -> JsonDict | None: if service == service_one: return { "x-protocol-data": 42, @@ -385,7 +384,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) def _mkservice( - self, is_interested_in_event: bool, protocols: Optional[Iterable] = None + self, is_interested_in_event: bool, protocols: Iterable | None = None ) -> Mock: """ Create a new mock representing an ApplicationService. @@ -1021,7 +1020,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): def _register_application_service( self, - namespaces: Optional[dict[str, Iterable[dict]]] = None, + namespaces: dict[str, Iterable[dict]] | None = None, ) -> ApplicationService: """ Register a new application service, with the given namespaces of interest. @@ -1316,8 +1315,8 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): # Capture what was sent as an AS transaction. self.send_mock.assert_called() last_args, _last_kwargs = self.send_mock.call_args - otks: Optional[TransactionOneTimeKeysCount] = last_args[self.ARG_OTK_COUNTS] - unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[ + otks: TransactionOneTimeKeysCount | None = last_args[self.ARG_OTK_COUNTS] + unused_fallbacks: TransactionUnusedFallbackKeys | None = last_args[ self.ARG_FALLBACK_KEYS ] diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index acefd707f5..648be7e7e7 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from unittest.mock import AsyncMock import pymacaroons @@ -55,7 +54,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1 = self.register_user("a_user", "pass") - def token_login(self, token: str) -> Optional[str]: + def token_login(self, token: str) -> str | None: body = { "type": "m.login.token", "token": token, diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 5b04da8640..acd37a1c71 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -20,7 +20,6 @@ # # -from typing import Optional from unittest import mock from twisted.internet.defer import ensureDeferred @@ -312,8 +311,8 @@ class DeviceTestCase(unittest.HomeserverTestCase): user_id: str, device_id: str, display_name: str, - access_token: Optional[str] = None, - ip: Optional[str] = None, + access_token: str | None = None, + ip: str | None = None, ) -> None: device_id = self.get_success( self.handler.check_device_registered( diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index c9ece68729..7085531548 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -19,7 +19,7 @@ # # import logging -from typing import Collection, Optional, cast +from typing import Collection, cast from unittest import TestCase from unittest.mock import AsyncMock, Mock, patch @@ -689,7 +689,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): return is_partial_state async def sync_partial_state_room( - initial_destination: Optional[str], + initial_destination: str | None, other_destinations: Collection[str], room_id: str, ) -> None: @@ -744,7 +744,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): return is_partial_state async def sync_partial_state_room( - initial_destination: Optional[str], + initial_destination: str | None, other_destinations: Collection[str], room_id: str, ) -> None: diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 5771699a62..3d856b9346 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from unittest import mock from twisted.internet.testing import MemoryReactor @@ -183,7 +182,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): else: async def get_event( - destination: str, event_id: str, timeout: Optional[int] = None + destination: str, event_id: str, timeout: int | None = None ) -> JsonDict: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, prev_event.event_id) @@ -585,7 +584,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): room_state_endpoint_requested_count = 0 async def get_event( - destination: str, event_id: str, timeout: Optional[int] = None + destination: str, event_id: str, timeout: int | None = None ) -> None: nonlocal event_endpoint_requested_count event_endpoint_requested_count += 1 @@ -1115,7 +1114,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ): async def get_event( - destination: str, event_id: str, timeout: Optional[int] = None + destination: str, event_id: str, timeout: int | None = None ) -> JsonDict: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, missing_event.event_id) diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 43004bfc69..c0a197874e 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -25,7 +25,7 @@ import time from http import HTTPStatus from http.server import BaseHTTPRequestHandler, HTTPServer from io import BytesIO -from typing import Any, ClassVar, Coroutine, Generator, Optional, TypeVar, Union +from typing import Any, ClassVar, Coroutine, Generator, TypeVar, Union from unittest.mock import ANY, AsyncMock, Mock from urllib.parse import parse_qs @@ -759,7 +759,7 @@ class FakeMasServer(HTTPServer): secret: str = "verysecret" """The shared secret used to authenticate the introspection endpoint.""" - last_token_seen: Optional[str] = None + last_token_seen: str | None = None """What is the last access token seen by the introspection endpoint.""" calls: int = 0 @@ -1110,7 +1110,7 @@ class DisabledEndpointsTestCase(HomeserverTestCase): return config def expect_unauthorized( - self, method: str, path: str, content: Union[bytes, str, JsonDict] = "" + self, method: str, path: str, content: bytes | str | JsonDict = "" ) -> None: channel = self.make_request(method, path, content, shorthand=False) @@ -1120,7 +1120,7 @@ class DisabledEndpointsTestCase(HomeserverTestCase): self, method: str, path: str, - content: Union[bytes, str, JsonDict] = "", + content: bytes | str | JsonDict = "", auth: bool = False, ) -> None: channel = self.make_request( @@ -1133,7 +1133,7 @@ class DisabledEndpointsTestCase(HomeserverTestCase): ) def expect_forbidden( - self, method: str, path: str, content: Union[bytes, str, JsonDict] = "" + self, method: str, path: str, content: bytes | str | JsonDict = "" ) -> None: channel = self.make_request(method, path, content) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 3180969e7b..4583afb625 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -19,7 +19,7 @@ # # import os -from typing import Any, Awaitable, ContextManager, Optional +from typing import Any, Awaitable, ContextManager from unittest.mock import ANY, AsyncMock, Mock, patch from urllib.parse import parse_qs, urlparse @@ -221,7 +221,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return _build_callback_request(code, state, session), grant def assertRenderedError( - self, error: str, error_description: Optional[str] = None + self, error: str, error_description: str | None = None ) -> tuple[Any, ...]: self.render_error.assert_called_once() args = self.render_error.call_args[0] diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index faa269bd35..573ba58c4f 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -22,7 +22,7 @@ """Tests for the password_auth_provider interface""" from http import HTTPStatus -from typing import Any, Optional, Union +from typing import Any from unittest.mock import AsyncMock, Mock from twisted.internet.testing import MemoryReactor @@ -707,7 +707,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.called = False async def on_logged_out( - user_id: str, device_id: Optional[str], access_token: str + user_id: str, device_id: str | None, access_token: str ) -> None: self.called = True @@ -978,7 +978,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self, access_token: str, device: str, - body: Union[JsonDict, bytes] = b"", + body: JsonDict | bytes = b"", ) -> FakeChannel: """Delete an individual device.""" channel = self.make_request( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index de1bc90c67..44f1e6432d 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -19,7 +19,7 @@ # # import itertools -from typing import Optional, cast +from typing import cast from unittest.mock import Mock, call from parameterized import parameterized @@ -1650,7 +1650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(state.state, PresenceState.ONLINE) def _set_presencestate_with_status_msg( - self, state: str, status_msg: Optional[str] + self, state: str, status_msg: str | None ) -> None: """Set a PresenceState and status_msg and check the result. diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 20c2554e25..0db7f30b1f 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -19,7 +19,7 @@ # # -from typing import Any, Collection, Optional +from typing import Any, Collection from unittest.mock import AsyncMock, Mock from twisted.internet.testing import MemoryReactor @@ -63,10 +63,10 @@ class TestSpamChecker: async def check_registration_for_spam( self, - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[tuple[str, str]], - auth_provider_id: Optional[str], + auth_provider_id: str | None, ) -> RegistrationBehaviour: return RegistrationBehaviour.ALLOW @@ -74,10 +74,10 @@ class TestSpamChecker: class DenyAll(TestSpamChecker): async def check_registration_for_spam( self, - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[tuple[str, str]], - auth_provider_id: Optional[str], + auth_provider_id: str | None, ) -> RegistrationBehaviour: return RegistrationBehaviour.DENY @@ -85,10 +85,10 @@ class DenyAll(TestSpamChecker): class BanAll(TestSpamChecker): async def check_registration_for_spam( self, - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[tuple[str, str]], - auth_provider_id: Optional[str], + auth_provider_id: str | None, ) -> RegistrationBehaviour: return RegistrationBehaviour.SHADOW_BAN @@ -96,10 +96,10 @@ class BanAll(TestSpamChecker): class BanBadIdPUser(TestSpamChecker): async def check_registration_for_spam( self, - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[tuple[str, str]], - auth_provider_id: Optional[str] = None, + auth_provider_id: str | None = None, ) -> RegistrationBehaviour: # Reject any user coming from CAS and whose username contains profanity if auth_provider_id == "cas" and username and "flimflob" in username: @@ -113,8 +113,8 @@ class TestLegacyRegistrationSpamChecker: async def check_registration_for_spam( self, - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[tuple[str, str]], ) -> RegistrationBehaviour: return RegistrationBehaviour.ALLOW @@ -123,8 +123,8 @@ class TestLegacyRegistrationSpamChecker: class LegacyAllowAll(TestLegacyRegistrationSpamChecker): async def check_registration_for_spam( self, - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[tuple[str, str]], ) -> RegistrationBehaviour: return RegistrationBehaviour.ALLOW @@ -133,8 +133,8 @@ class LegacyAllowAll(TestLegacyRegistrationSpamChecker): class LegacyDenyAll(TestLegacyRegistrationSpamChecker): async def check_registration_for_spam( self, - email_threepid: Optional[dict], - username: Optional[str], + email_threepid: dict | None, + username: str | None, request_info: Collection[tuple[str, str]], ) -> RegistrationBehaviour: return RegistrationBehaviour.DENY @@ -777,8 +777,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self, requester: Requester, localpart: str, - displayname: Optional[str], - password_hash: Optional[str] = None, + displayname: str | None, + password_hash: str | None = None, ) -> tuple[str, str]: """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. diff --git a/tests/handlers/test_room_list.py b/tests/handlers/test_room_list.py index f6e9309f1f..e7c4436d1d 100644 --- a/tests/handlers/test_room_list.py +++ b/tests/handlers/test_room_list.py @@ -1,5 +1,4 @@ from http import HTTPStatus -from typing import Optional from synapse.rest import admin from synapse.rest.client import directory, login, room @@ -18,7 +17,7 @@ class RoomListHandlerTestCase(unittest.HomeserverTestCase): ] def _create_published_room( - self, tok: str, extra_content: Optional[JsonDict] = None + self, tok: str, extra_content: JsonDict | None = None ) -> str: room_id = self.helper.create_room_as(tok=tok, extra_content=extra_content) channel = self.make_request( diff --git a/tests/handlers/test_room_policy.py b/tests/handlers/test_room_policy.py index 00da1d942f..ff212ab06e 100644 --- a/tests/handlers/test_room_policy.py +++ b/tests/handlers/test_room_policy.py @@ -12,7 +12,6 @@ # . # # -from typing import Optional from unittest import mock import signedjson @@ -113,7 +112,7 @@ class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase): async def get_policy_recommendation_for_pdu( destination: str, pdu: EventBase, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> JsonDict: self.call_count += 1 self.assertEqual(destination, self.OTHER_SERVER_NAME) @@ -128,8 +127,8 @@ class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase): # Mock policy server actions on signing events async def policy_server_signs_event( - destination: str, pdu: EventBase, timeout: Optional[int] = None - ) -> Optional[JsonDict]: + destination: str, pdu: EventBase, timeout: int | None = None + ) -> JsonDict | None: sigs = compute_event_signature( pdu.room_version, pdu.get_dict(), @@ -139,8 +138,8 @@ class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase): return sigs async def policy_server_signs_event_with_wrong_key( - destination: str, pdu: EventBase, timeout: Optional[int] = None - ) -> Optional[JsonDict]: + destination: str, pdu: EventBase, timeout: int | None = None + ) -> JsonDict | None: sk = signedjson.key.generate_signing_key("policy_server") sigs = compute_event_signature( pdu.room_version, @@ -151,13 +150,13 @@ class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase): return sigs async def policy_server_refuses_to_sign_event( - destination: str, pdu: EventBase, timeout: Optional[int] = None - ) -> Optional[JsonDict]: + destination: str, pdu: EventBase, timeout: int | None = None + ) -> JsonDict | None: return {} async def policy_server_event_sign_error( - destination: str, pdu: EventBase, timeout: Optional[int] = None - ) -> Optional[JsonDict]: + destination: str, pdu: EventBase, timeout: int | None = None + ) -> JsonDict | None: return None self.policy_server_signs_event = policy_server_signs_event @@ -167,7 +166,7 @@ class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase): policy_server_signs_event_with_wrong_key ) - def _add_policy_server_to_room(self, public_key: Optional[str] = None) -> None: + def _add_policy_server_to_room(self, public_key: str | None = None) -> None: # Inject a member event into the room policy_user_id = f"@policy:{self.OTHER_SERVER_NAME}" self.get_success( @@ -442,7 +441,7 @@ class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase): f"event did not include policy server signature, signature block = {ev.get('signatures', None)}", ) - def _fetch_federation_event(self, event_id: str) -> Optional[JsonDict]: + def _fetch_federation_event(self, event_id: str) -> JsonDict | None: # Request federation events to see the signatures channel = self.make_request( "POST", diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index 3c8c483921..ee65cb1afb 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Any, Iterable, Optional +from typing import Any, Iterable from unittest import mock from twisted.internet.defer import ensureDeferred @@ -49,7 +49,7 @@ from tests.unittest import override_config def _create_event( - room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0 + room_id: str, order: Any | None = None, origin_server_ts: int = 0 ) -> mock.Mock: result = mock.Mock(name=room_id) result.room_id = room_id @@ -151,8 +151,8 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): space_id: str, room_id: str, token: str, - order: Optional[str] = None, - via: Optional[list[str]] = None, + order: str | None = None, + via: list[str] | None = None, ) -> None: """Add a child room to a space.""" if via is None: @@ -393,7 +393,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_hierarchy(result2, [(self.space, [self.room])]) def _create_room_with_join_rule( - self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any + self, join_rule: str, room_version: str | None = None, **extra_content: Any ) -> str: """Create a room with the given join rule and add it to the space.""" room_id = self.helper.create_room_as( @@ -740,7 +740,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): async def summarize_remote_room_hierarchy( _self: Any, room: Any, suggested_only: bool - ) -> tuple[Optional[_RoomEntry], dict[str, JsonDict], set[str]]: + ) -> tuple[_RoomEntry | None, dict[str, JsonDict], set[str]]: return requested_room_entry, {subroom: child_room}, set() # Add a room to the space which is on another server. @@ -793,7 +793,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): async def summarize_remote_room_hierarchy( _self: Any, room: Any, suggested_only: bool - ) -> tuple[Optional[_RoomEntry], dict[str, JsonDict], set[str]]: + ) -> tuple[_RoomEntry | None, dict[str, JsonDict], set[str]]: return requested_room_entry, {fed_subroom: child_room}, set() expected = [ @@ -921,7 +921,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): async def summarize_remote_room_hierarchy( _self: Any, room: Any, suggested_only: bool - ) -> tuple[Optional[_RoomEntry], dict[str, JsonDict], set[str]]: + ) -> tuple[_RoomEntry | None, dict[str, JsonDict], set[str]]: return subspace_room_entry, dict(children_rooms), set() # Add a room to the space which is on another server. @@ -985,7 +985,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): async def summarize_remote_room_hierarchy( _self: Any, room: Any, suggested_only: bool - ) -> tuple[Optional[_RoomEntry], dict[str, JsonDict], set[str]]: + ) -> tuple[_RoomEntry | None, dict[str, JsonDict], set[str]]: return fed_room_entry, {}, set() # Add a room to the space which is on another server. @@ -1120,7 +1120,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): async def summarize_remote_room_hierarchy( _self: Any, room: Any, suggested_only: bool - ) -> tuple[Optional[_RoomEntry], dict[str, JsonDict], set[str]]: + ) -> tuple[_RoomEntry | None, dict[str, JsonDict], set[str]]: return requested_room_entry, {fed_subroom: child_room}, set() expected = [ @@ -1233,7 +1233,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): async def summarize_remote_room_hierarchy( _self: Any, room: Any, suggested_only: bool - ) -> tuple[Optional[_RoomEntry], dict[str, JsonDict], set[str]]: + ) -> tuple[_RoomEntry | None, dict[str, JsonDict], set[str]]: return requested_room_entry, {}, set() with mock.patch( diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 28159abbcb..c2aeab5f7e 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -19,7 +19,7 @@ # # -from typing import Any, Optional +from typing import Any from unittest.mock import AsyncMock, Mock import attr @@ -61,7 +61,7 @@ BASE_URL = "https://synapse/" class FakeAuthnResponse: ava = attr.ib(type=dict) assertions = attr.ib(type=list, factory=list) - in_response_to = attr.ib(type=Optional[str], default=None) + in_response_to = attr.ib(type=(str | None), default=None) class TestMappingProvider: diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index d033ed3a1c..eea88cd136 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -20,7 +20,7 @@ # -from typing import Callable, Union +from typing import Callable from unittest.mock import patch from zope.interface import implementer @@ -104,7 +104,7 @@ class _DummyMessage: class SendEmailHandlerTestCaseIPv4(HomeserverTestCase): - ip_class: Union[type[IPv4Address], type[IPv6Address]] = IPv4Address + ip_class: type[IPv4Address] | type[IPv6Address] = IPv4Address def setUp(self) -> None: super().setUp() diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index a35910e4dd..4582906441 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -18,7 +18,7 @@ # # import logging -from typing import AbstractSet, Mapping, Optional +from typing import AbstractSet, Mapping from unittest.mock import patch import attr @@ -62,7 +62,7 @@ class RoomSyncConfigTestCase(TestCase): self, actual: RoomSyncConfig, expected: RoomSyncConfig, - message_prefix: Optional[str] = None, + message_prefix: str | None = None, ) -> None: self.assertEqual(actual.timeline_limit, expected.timeline_limit, message_prefix) @@ -3277,7 +3277,7 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): self, user: UserID, to_token: StreamToken, - from_token: Optional[StreamToken], + from_token: StreamToken | None, ) -> tuple[dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]: """ Get the rooms the user should be syncing with @@ -3614,7 +3614,7 @@ class SortRoomsTestCase(HomeserverTestCase): self, user: UserID, to_token: StreamToken, - from_token: Optional[StreamToken], + from_token: StreamToken | None, ) -> tuple[dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]: """ Get the rooms the user should be syncing with @@ -3828,10 +3828,10 @@ class RequiredStateChangesTestParameters: request_required_state_map: dict[str, set[str]] state_deltas: StateMap[str] expected_with_state_deltas: tuple[ - Optional[Mapping[str, AbstractSet[str]]], StateFilter + Mapping[str, AbstractSet[str]] | None, StateFilter ] expected_without_state_deltas: tuple[ - Optional[Mapping[str, AbstractSet[str]]], StateFilter + Mapping[str, AbstractSet[str]] | None, StateFilter ] diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py index 5ac088f601..95595b8ff9 100644 --- a/tests/handlers/test_sso.py +++ b/tests/handlers/test_sso.py @@ -18,7 +18,7 @@ # # from http import HTTPStatus -from typing import BinaryIO, Callable, Optional +from typing import BinaryIO, Callable from unittest.mock import Mock from twisted.internet.testing import MemoryReactor @@ -117,9 +117,9 @@ class TestSSOHandler(unittest.HomeserverTestCase): async def mock_get_file( url: str, output_stream: BinaryIO, - max_size: Optional[int] = None, - headers: Optional[RawHeaders] = None, - is_allowed_content_type: Optional[Callable[[str], bool]] = None, + max_size: int | None = None, + headers: RawHeaders | None = None, + is_allowed_content_type: Callable[[str], bool] | None = None, ) -> tuple[int, dict[bytes, list[bytes]], str, int]: fake_response = FakeResponse(code=404) if url == "http://my.server/me.png": diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 94f5e472ca..0072327044 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -18,7 +18,7 @@ # # -from typing import Any, Optional, cast +from typing import Any, cast from twisted.internet.testing import MemoryReactor @@ -74,9 +74,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - async def get_all_room_state(self) -> list[Optional[str]]: + async def get_all_room_state(self) -> list[str | None]: rows = cast( - list[tuple[Optional[str]]], + list[tuple[str | None]], await self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("topic",) ), @@ -85,7 +85,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def _get_current_stats( self, stats_type: str, stat_id: str - ) -> Optional[dict[str, Any]]: + ) -> dict[str, Any] | None: table, id_col = stats.TYPE_TO_TABLE[stats_type] cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 140dd4a0ba..18ec2ca6b6 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -18,7 +18,7 @@ # # from http import HTTPStatus -from typing import Collection, ContextManager, Optional +from typing import Collection, ContextManager from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized, parameterized_class @@ -893,7 +893,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): federation_event_handler = self.hs.get_federation_event_handler() async def _check_event_auth( - origin: Optional[str], event: EventBase, context: EventContext + origin: str | None, event: EventBase, context: EventContext ) -> None: pass @@ -1117,8 +1117,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def generate_sync_config( user_id: str, - device_id: Optional[str] = "device_id", - filter_collection: Optional[FilterCollection] = None, + device_id: str | None = "device_id", + filter_collection: FilterCollection | None = None, use_state_after: bool = False, ) -> SyncConfig: """Generate a sync config (with a unique request key). diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 949564fcc7..49ecaa30ff 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -20,7 +20,7 @@ import base64 import logging import os -from typing import Generator, Optional, cast +from typing import Generator, cast from unittest.mock import AsyncMock, call, patch import treq @@ -85,7 +85,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.tls_factory = FederationPolicyForHTTPS(config) - self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache( + self.well_known_cache: TTLCache[bytes, bytes | None] = TTLCache( cache_name="test_cache", server_name="test_server", timer=self.reactor.seconds, @@ -109,8 +109,8 @@ class MatrixFederationAgentTests(unittest.TestCase): self, client_factory: IProtocolFactory, ssl: bool = True, - expected_sni: Optional[bytes] = None, - tls_sanlist: Optional[list[bytes]] = None, + expected_sni: bytes | None = None, + tls_sanlist: list[bytes] | None = None, ) -> HTTPChannel: """Builds a test server, and completes the outgoing client connection Args: @@ -228,7 +228,7 @@ class MatrixFederationAgentTests(unittest.TestCase): client_factory: IProtocolFactory, expected_sni: bytes, content: bytes, - response_headers: Optional[dict] = None, + response_headers: dict | None = None, ) -> HTTPChannel: """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -257,7 +257,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self, request: Request, content: bytes, - headers: Optional[dict] = None, + headers: dict | None = None, ) -> None: """Check that an incoming request looks like a valid .well-known request, and send back the response. @@ -397,7 +397,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def _do_get_via_proxy( self, expect_proxy_ssl: bool = False, - expected_auth_credentials: Optional[bytes] = None, + expected_auth_credentials: bytes | None = None, ) -> None: """Send a https federation request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index cc9b5fd6e1..afa69d1b7b 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -27,9 +27,7 @@ from typing import ( Callable, ContextManager, Generator, - Optional, TypeVar, - Union, ) from unittest import mock from unittest.mock import Mock @@ -65,8 +63,8 @@ def test_disconnect( reactor: MemoryReactorClock, channel: FakeChannel, expect_cancellation: bool, - expected_body: Union[bytes, JsonDict], - expected_code: Optional[int] = None, + expected_body: bytes | JsonDict, + expected_code: int | None = None, ) -> None: """Disconnects an in-flight request and checks the response. @@ -146,9 +144,9 @@ def make_request_with_cancellation_test( site: Site, method: str, path: str, - content: Union[bytes, str, JsonDict] = b"", + content: bytes | str | JsonDict = b"", *, - token: Optional[str] = None, + token: str | None = None, ) -> FakeChannel: """Performs a request repeatedly, disconnecting at successive `await`s, until one completes. @@ -361,7 +359,7 @@ class Deferred__await__Patch: # unresolved `Deferred` and return it out of `Deferred.__await__` / # `coroutine.send()`. We have to resolve it later, in case the `await`ing # coroutine is part of some shared processing, such as `@cached`. - self._to_unblock: dict[Deferred, Union[object, Failure]] = {} + self._to_unblock: dict[Deferred, object | Failure] = {} # The last stack we logged. self._previous_stack: list[inspect.FrameInfo] = [] diff --git a/tests/http/test_client.py b/tests/http/test_client.py index d9eaa78a39..5c8c1220e4 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -20,7 +20,6 @@ # from io import BytesIO -from typing import Union from unittest.mock import Mock from netaddr import IPSet @@ -58,7 +57,7 @@ class ReadMultipartResponseTests(TestCase): redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" def _build_multipart_response( - self, response_length: Union[int, str], max_length: int + self, response_length: int | str, max_length: int ) -> tuple[ BytesIO, "Deferred[MultipartResponse]", @@ -208,7 +207,7 @@ class ReadMultipartResponseTests(TestCase): class ReadBodyWithMaxSizeTests(TestCase): def _build_response( - self, length: Union[int, str] = UNKNOWN_LENGTH + self, length: int | str = UNKNOWN_LENGTH ) -> tuple[ BytesIO, "Deferred[int]", diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index a9b4f3d956..c65115b3e5 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -21,7 +21,6 @@ import base64 import logging import os -from typing import Optional from unittest.mock import patch import treq @@ -195,7 +194,7 @@ class ProxyParserTests(TestCase): expected_scheme: bytes, expected_hostname: bytes, expected_port: int, - expected_credentials: Optional[bytes], + expected_credentials: bytes | None, ) -> None: """ Tests that a given proxy URL will be broken into the components. @@ -251,8 +250,8 @@ class ProxyAgentTests(TestCase): client_factory: IProtocolFactory, server_factory: IProtocolFactory, ssl: bool = False, - expected_sni: Optional[bytes] = None, - tls_sanlist: Optional[list[bytes]] = None, + expected_sni: bytes | None = None, + tls_sanlist: list[bytes] | None = None, ) -> IProtocol: """Builds a test server, and completes the outgoing client connection @@ -602,7 +601,7 @@ class ProxyAgentTests(TestCase): self, proxy_config: ProxyConfig, expect_proxy_ssl: bool = False, - expected_auth_credentials: Optional[bytes] = None, + expected_auth_credentials: bytes | None = None, ) -> None: """Send a http request via an agent and check that it is correctly received at the proxy. The proxy can use either http or https. @@ -682,7 +681,7 @@ class ProxyAgentTests(TestCase): self, proxy_config: ProxyConfig, expect_proxy_ssl: bool = False, - expected_auth_credentials: Optional[bytes] = None, + expected_auth_credentials: bytes | None = None, ) -> None: """Send a https request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index 087191b220..5bf8305d05 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -21,7 +21,6 @@ import json from http import HTTPStatus from io import BytesIO -from typing import Union from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError @@ -40,7 +39,7 @@ from tests import unittest from tests.http.server._base import test_disconnect -def make_request(content: Union[bytes, JsonDict]) -> Mock: +def make_request(content: bytes | JsonDict) -> Mock: """Make an object that acts enough like a request.""" request = Mock(spec=["method", "uri", "content"]) diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index 2f389f7f44..3aaa743265 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -19,7 +19,7 @@ # # -from typing import Awaitable, Optional, cast +from typing import Awaitable, cast from twisted.internet import defer from twisted.internet.testing import MemoryReactorClock @@ -329,7 +329,7 @@ class LogContextScopeManagerTestCase(TestCase): reactor, clock = get_clock() callback_finished = False - active_span_in_callback: Optional[jaeger_client.Span] = None + active_span_in_callback: jaeger_client.Span | None = None async def bg_task() -> None: nonlocal callback_finished, active_span_in_callback @@ -391,7 +391,7 @@ class LogContextScopeManagerTestCase(TestCase): reactor, clock = get_clock() callback_finished = False - active_span_in_callback: Optional[jaeger_client.Span] = None + active_span_in_callback: jaeger_client.Span | None = None async def bg_task() -> None: nonlocal callback_finished, active_span_in_callback @@ -461,7 +461,7 @@ class LogContextScopeManagerTestCase(TestCase): span.span_id: span.operation_name for span in self._reporter.get_spans() } - def get_span_friendly_name(span_id: Optional[int]) -> str: + def get_span_friendly_name(span_id: int | None) -> str: if span_id is None: return "None" diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py index 6dba214514..f27a9ed685 100644 --- a/tests/media/test_media_retention.py +++ b/tests/media/test_media_retention.py @@ -20,7 +20,7 @@ # import io -from typing import Iterable, Optional +from typing import Iterable from matrix_common.types.mxc_uri import MXCUri @@ -63,9 +63,9 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): media_repository = hs.get_media_repository() def _create_media_and_set_attributes( - last_accessed_ms: Optional[int], - is_quarantined: Optional[bool] = False, - is_protected: Optional[bool] = False, + last_accessed_ms: int | None, + is_quarantined: bool | None = False, + is_protected: bool | None = False, ) -> MXCUri: # "Upload" some media to the local media store # If the meda @@ -113,8 +113,8 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): def _cache_remote_media_and_set_attributes( media_id: str, - last_accessed_ms: Optional[int], - is_quarantined: Optional[bool] = False, + last_accessed_ms: int | None, + is_quarantined: bool | None = False, ) -> MXCUri: # Pretend to cache some remote media self.get_success( diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index d584ea951c..e56354e0b3 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -23,7 +23,7 @@ import shutil import tempfile from binascii import unhexlify from io import BytesIO -from typing import Any, BinaryIO, ClassVar, Literal, Optional, Union +from typing import Any, BinaryIO, ClassVar, Literal from unittest.mock import MagicMock, Mock, patch from urllib import parse @@ -150,8 +150,8 @@ class TestImage: data: bytes content_type: bytes extension: bytes - expected_cropped: Optional[bytes] = None - expected_scaled: Optional[bytes] = None + expected_cropped: bytes | None = None + expected_scaled: bytes | None = None expected_found: bool = True unable_to_thumbnail: bool = False is_inline: bool = True @@ -302,7 +302,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): "Deferred[tuple[bytes, tuple[int, dict[bytes, list[bytes]]]]]", str, str, - Optional[QueryParams], + QueryParams | None, ] ] = [] @@ -313,7 +313,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): download_ratelimiter: Ratelimiter, ip_address: Any, max_size: int, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, ignore_backoff: bool = False, follow_redirects: bool = False, @@ -376,7 +376,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): return resources def _req( - self, content_disposition: Optional[bytes], include_content_type: bool = True + self, content_disposition: bytes | None, include_content_type: bool = True ) -> FakeChannel: channel = self.make_request( "GET", @@ -654,7 +654,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _test_thumbnail( self, method: str, - expected_body: Optional[bytes], + expected_body: bytes | None, expected_found: bool, unable_to_thumbnail: bool = False, ) -> None: @@ -868,7 +868,7 @@ class TestSpamCheckerLegacy: def parse_config(config: dict[str, Any]) -> dict[str, Any]: return config - async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]: + async def check_event_for_spam(self, event: EventBase) -> bool | str: return False # allow all events async def user_may_invite( @@ -972,7 +972,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): async def check_media_file_for_spam( self, file_wrapper: ReadableFileWrapper, file_info: FileInfo - ) -> Union[Codes, Literal["NOT_SPAM"], tuple[Codes, JsonDict]]: + ) -> Codes | Literal["NOT_SPAM"] | tuple[Codes, JsonDict]: buf = BytesIO() await file_wrapper.write_chunks_to(buf.write) @@ -1259,7 +1259,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): def read_body( - response: IResponse, stream: ByteWriteable, max_size: Optional[int] + response: IResponse, stream: ByteWriteable, max_size: int | None ) -> Deferred: d: Deferred = defer.Deferred() stream.write(SMALL_PNG) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index b768a913d7..12c8942bc8 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Any, Optional +from typing import Any from unittest.mock import AsyncMock, Mock from twisted.internet import defer @@ -803,10 +803,10 @@ class ModuleApiTestCase(BaseModuleApiTestCase): ) # Setup a callback counting the number of pushers. - number_of_pushers_in_callback: Optional[int] = None + number_of_pushers_in_callback: int | None = None async def _on_logged_out_mock( - user_id: str, device_id: Optional[str], access_token: str + user_id: str, device_id: str | None, access_token: str ) -> None: nonlocal number_of_pushers_in_callback number_of_pushers_in_callback = len( diff --git a/tests/module_api/test_spamchecker.py b/tests/module_api/test_spamchecker.py index d461d6cea2..42ef969ce0 100644 --- a/tests/module_api/test_spamchecker.py +++ b/tests/module_api/test_spamchecker.py @@ -12,7 +12,7 @@ # . # # -from typing import Literal, Union +from typing import Literal from twisted.internet.testing import MemoryReactor @@ -59,7 +59,7 @@ class SpamCheckerTestCase(HomeserverTestCase): async def user_may_create_room( user_id: str, room_config: JsonDict - ) -> Union[Literal["NOT_SPAM"], Codes]: + ) -> Literal["NOT_SPAM"] | Codes: self.last_room_config = room_config self.last_user_id = user_id return "NOT_SPAM" @@ -82,7 +82,7 @@ class SpamCheckerTestCase(HomeserverTestCase): async def user_may_create_room( user_id: str, room_config: JsonDict - ) -> Union[Literal["NOT_SPAM"], Codes]: + ) -> Literal["NOT_SPAM"] | Codes: self.last_room_config = room_config self.last_user_id = user_id return "NOT_SPAM" @@ -117,7 +117,7 @@ class SpamCheckerTestCase(HomeserverTestCase): async def user_may_create_room( user_id: str, room_config: JsonDict - ) -> Union[Literal["NOT_SPAM"], Codes]: + ) -> Literal["NOT_SPAM"] | Codes: self.last_room_config = room_config self.last_user_id = user_id return "NOT_SPAM" @@ -156,7 +156,7 @@ class SpamCheckerTestCase(HomeserverTestCase): async def user_may_create_room( user_id: str, room_config: JsonDict - ) -> Union[Literal["NOT_SPAM"], Codes]: + ) -> Literal["NOT_SPAM"] | Codes: self.last_room_config = room_config self.last_user_id = user_id return Codes.UNAUTHORIZED @@ -181,7 +181,7 @@ class SpamCheckerTestCase(HomeserverTestCase): async def user_may_create_room( user_id: str, - ) -> Union[Literal["NOT_SPAM"], Codes]: + ) -> Literal["NOT_SPAM"] | Codes: self.last_user_id = user_id return "NOT_SPAM" @@ -205,7 +205,7 @@ class SpamCheckerTestCase(HomeserverTestCase): event_type: str, state_key: str, content: JsonDict, - ) -> Union[Literal["NOT_SPAM"], Codes]: + ) -> Literal["NOT_SPAM"] | Codes: self.last_user_id = user_id self.last_room_id = room_id self.last_event_type = event_type @@ -255,7 +255,7 @@ class SpamCheckerTestCase(HomeserverTestCase): event_type: str, state_key: str, content: JsonDict, - ) -> Union[Literal["NOT_SPAM"], Codes]: + ) -> Literal["NOT_SPAM"] | Codes: return Codes.FORBIDDEN self._module_api.register_spam_checker_callbacks( diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 560d7234ec..137bbe24b2 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -20,7 +20,7 @@ # from http import HTTPStatus -from typing import Any, Optional +from typing import Any from unittest.mock import AsyncMock, patch from parameterized import parameterized @@ -210,7 +210,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): def _create_and_process( self, bulk_evaluator: BulkPushRuleEvaluator, - content: Optional[JsonDict] = None, + content: JsonDict | None = None, type: str = "test", ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py index 4982a80cce..2558f2c0b2 100644 --- a/tests/push/test_presentable_names.py +++ b/tests/push/test_presentable_names.py @@ -19,7 +19,7 @@ # # -from typing import Iterable, Optional, cast +from typing import Iterable, cast from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions @@ -59,7 +59,7 @@ class MockDataStore: async def get_event( self, event_id: str, allow_none: bool = False - ) -> Optional[FrozenEvent]: + ) -> FrozenEvent | None: assert allow_none, "Mock not configured for allow_none = False" # Decode the state key from the event ID. @@ -81,7 +81,7 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase): user_id: str = "", fallback_to_members: bool = True, fallback_to_single_member: bool = True, - ) -> Optional[str]: + ) -> str | None: # Encode the state key into the event ID. room_state_ids = {k[0]: "|".join(k[0]) for k in events} diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index b1f7ba6973..a786d74bf1 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -19,7 +19,7 @@ # # -from typing import Any, Optional, Union, cast +from typing import Any, cast from twisted.internet.testing import MemoryReactor @@ -148,7 +148,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): self, content: JsonMapping, *, - related_events: Optional[JsonDict] = None, + related_events: JsonDict | None = None, msc4210: bool = False, msc4306: bool = False, ) -> PushRuleEvaluator: @@ -165,7 +165,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): ) room_member_count = 0 sender_power_level = 0 - power_levels: dict[str, Union[int, dict[str, int]]] = {} + power_levels: dict[str, int | dict[str, int]] = {} return PushRuleEvaluator( _flatten_dict(event), False, @@ -205,13 +205,13 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) def _assert_matches( - self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None + self, condition: JsonDict, content: JsonMapping, msg: str | None = None ) -> None: evaluator = self._get_evaluator(content) self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg) def _assert_not_matches( - self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None + self, condition: JsonDict, content: JsonDict, msg: str | None = None ) -> None: evaluator = self._get_evaluator(content) self.assertFalse( @@ -588,7 +588,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): This tests the behaviour of tweaks_for_actions. """ - actions: list[Union[dict[str, str], str]] = [ + actions: list[dict[str, str] | str] = [ {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}, "notify", diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 84bdc84ce9..b23696668f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -19,7 +19,7 @@ # import logging from collections import defaultdict -from typing import Any, Optional +from typing import Any from twisted.internet.address import IPv4Address from twisted.internet.protocol import Protocol, connectionDone @@ -105,8 +105,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): repl_handler, ) - self._client_transport: Optional[FakeTransport] = None - self._server_transport: Optional[FakeTransport] = None + self._client_transport: FakeTransport | None = None + self._server_transport: FakeTransport | None = None def create_resource_dict(self) -> dict[str, Resource]: d = super().create_resource_dict() @@ -325,7 +325,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): return resource def make_worker_hs( - self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any + self, worker_app: str, extra_config: dict | None = None, **kwargs: Any ) -> HomeServer: """Make a new worker HS instance, correctly connecting replication stream to the master HS. @@ -534,7 +534,7 @@ class FakeRedisPubSubServer: class FakeRedisPubSubProtocol(Protocol): """A connection from a client talking to the fake Redis server.""" - transport: Optional[FakeTransport] = None + transport: FakeTransport | None = None def __init__(self, server: FakeRedisPubSubServer): self._server = server diff --git a/tests/replication/storage/_base.py b/tests/replication/storage/_base.py index fb99cb2335..7b757e9e9e 100644 --- a/tests/replication/storage/_base.py +++ b/tests/replication/storage/_base.py @@ -19,7 +19,7 @@ # # -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Iterable from unittest.mock import Mock from twisted.internet.testing import MemoryReactor @@ -56,8 +56,8 @@ class BaseWorkerStoreTestCase(BaseStreamTestCase): self, method: str, args: Iterable[Any], - expected_result: Optional[Any] = None, - asserter: Optional[Callable[[Any, Any, Optional[Any]], None]] = None, + expected_result: Any | None = None, + asserter: Callable[[Any, Any, Any | None], None] | None = None, ) -> None: if asserter is None: asserter = self.assertEqual diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py index 1398689c2d..28bfb8b8ea 100644 --- a/tests/replication/storage/test_events.py +++ b/tests/replication/storage/test_events.py @@ -19,7 +19,7 @@ # # import logging -from typing import Any, Iterable, Optional +from typing import Any, Iterable from canonicaljson import encode_canonical_json from parameterized import parameterized @@ -66,7 +66,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): ) def assertEventsEqual( - self, first: EventBase, second: EventBase, msg: Optional[Any] = None + self, first: EventBase, second: EventBase, msg: Any | None = None ) -> None: self.assertEqual( encode_canonical_json(first.get_pdu_json()), @@ -241,13 +241,13 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): sender: str = USER_ID, room_id: str = ROOM_ID, type: str = "m.room.message", - key: Optional[str] = None, - internal: Optional[dict] = None, - depth: Optional[int] = None, - prev_events: Optional[list[tuple[str, dict]]] = None, - auth_events: Optional[list[str]] = None, - prev_state: Optional[list[str]] = None, - redacts: Optional[str] = None, + key: str | None = None, + internal: dict | None = None, + depth: int | None = None, + prev_events: list[tuple[str, dict]] | None = None, + auth_events: list[str] | None = None, + prev_state: list[str] | None = None, + redacts: str | None = None, push_actions: Iterable = frozenset(), **content: object, ) -> tuple[EventBase, EventContext]: diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 9607c03224..484fd6b6db 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -18,7 +18,7 @@ # # -from typing import Any, Optional +from typing import Any from parameterized import parameterized @@ -517,7 +517,7 @@ class EventsStreamTestCase(BaseStreamTestCase): event_count = 0 def _inject_test_event( - self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any + self, body: str | None = None, sender: str | None = None, **kwargs: Any ) -> EventBase: if sender is None: sender = self.user_id @@ -539,9 +539,9 @@ class EventsStreamTestCase(BaseStreamTestCase): def _inject_state_event( self, - body: Optional[str] = None, - state_key: Optional[str] = None, - sender: Optional[str] = None, + body: str | None = None, + state_key: str | None = None, + sender: str | None = None, ) -> EventBase: if sender is None: sender = self.user_id diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 193c6c0198..8dbb989850 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -20,7 +20,7 @@ # import logging import os -from typing import Any, Optional +from typing import Any from twisted.internet.protocol import Factory from twisted.internet.testing import MemoryReactor @@ -44,7 +44,7 @@ from tests.unittest import override_config logger = logging.getLogger(__name__) -test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None +test_server_connection_factory: TestServerTLSConnectionFactory | None = None class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): @@ -67,7 +67,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): return conf def make_worker_hs( - self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any + self, worker_app: str, extra_config: dict | None = None, **kwargs: Any ) -> HomeServer: worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs) # Force the media paths onto the replication resource. @@ -282,7 +282,7 @@ class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): return conf def make_worker_hs( - self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any + self, worker_app: str, extra_config: dict | None = None, **kwargs: Any ) -> HomeServer: worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs) # Force the media paths onto the replication resource. diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 5586bb47e1..561566de76 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from parameterized import parameterized @@ -273,8 +272,8 @@ class FederationTestCase(unittest.HomeserverTestCase): def _order_test( expected_destination_list: list[str], - order_by: Optional[str], - dir: Optional[str] = None, + order_by: str | None, + dir: str | None = None, ) -> None: """Request the list of destinations in a certain order. Assert that order is what we expect @@ -366,7 +365,7 @@ class FederationTestCase(unittest.HomeserverTestCase): """Test that searching for a destination works correctly""" def _search_test( - expected_destination: Optional[str], + expected_destination: str | None, search_term: str, ) -> None: """Search for a destination and check that the returned destinationis a match @@ -484,10 +483,10 @@ class FederationTestCase(unittest.HomeserverTestCase): def _create_destination( self, destination: str, - failure_ts: Optional[int] = None, + failure_ts: int | None = None, retry_last_ts: int = 0, retry_interval: int = 0, - last_successful_stream_ordering: Optional[int] = None, + last_successful_stream_ordering: int | None = None, ) -> None: """Create one specific destination @@ -819,7 +818,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): def _create_destination_rooms( self, number_rooms: int, - destination: Optional[str] = None, + destination: str | None = None, ) -> list[str]: """ Create the given number of rooms. The given `destination` homeserver will diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 9afe86b724..447e1098e5 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -20,7 +20,6 @@ # import random import string -from typing import Optional from twisted.internet.testing import MemoryReactor @@ -51,11 +50,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def _new_token( self, - token: Optional[str] = None, - uses_allowed: Optional[int] = None, + token: str | None = None, + uses_allowed: int | None = None, pending: int = 0, completed: int = 0, - expiry_time: Optional[int] = None, + expiry_time: int | None = None, ) -> str: """Helper function to create a token.""" if token is None: diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 40b34f4433..7daf13ad22 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -22,7 +22,6 @@ import json import time import urllib.parse from http import HTTPStatus -from typing import Optional from unittest.mock import AsyncMock, Mock from parameterized import parameterized @@ -2074,7 +2073,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self._set_canonical_alias(room_id_1, "#Room_Alias1:test", self.admin_user_tok) def _search_test( - expected_room_id: Optional[str], + expected_room_id: str | None, search_term: str, expected_http_code: int = 200, ) -> None: diff --git a/tests/rest/admin/test_scheduled_tasks.py b/tests/rest/admin/test_scheduled_tasks.py index 264c62e2de..fb275f6d55 100644 --- a/tests/rest/admin/test_scheduled_tasks.py +++ b/tests/rest/admin/test_scheduled_tasks.py @@ -13,7 +13,7 @@ # # # -from typing import Mapping, Optional +from typing import Mapping from twisted.internet.testing import MemoryReactor @@ -42,17 +42,17 @@ class ScheduledTasksAdminApiTestCase(unittest.HomeserverTestCase): # create and schedule a few tasks async def _test_task( task: ScheduledTask, - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: return TaskStatus.ACTIVE, None, None async def _finished_test_task( task: ScheduledTask, - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: return TaskStatus.COMPLETE, None, None async def _failed_test_task( task: ScheduledTask, - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: return TaskStatus.FAILED, None, "Everything failed" self._task_scheduler.register_action(_test_task, "test_task") diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index a18952983e..3dc7e5dc97 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -19,7 +19,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource @@ -497,7 +496,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertIn("media_length", c) def _order_test( - self, order_type: str, expected_user_list: list[str], dir: Optional[str] = None + self, order_type: str, expected_user_list: list[str], dir: str | None = None ) -> None: """Request the list of users in a certain order. Assert that order is what we expect diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 040b21d471..6d0584fa63 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -27,7 +27,6 @@ import time import urllib.parse from binascii import unhexlify from http import HTTPStatus -from typing import Optional from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized, parameterized_class @@ -643,10 +642,10 @@ class UsersListTestCase(unittest.HomeserverTestCase): """Test that searching for a users works correctly""" def _search_test( - expected_user_id: Optional[str], + expected_user_id: str | None, search_term: str, - search_field: Optional[str] = "name", - expected_http_code: Optional[int] = 200, + search_field: str | None = "name", + expected_http_code: int | None = 200, ) -> None: """Search for a user and check that the returned user's id is a match @@ -1185,7 +1184,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) def test_user_type( - expected_user_ids: list[str], not_user_types: Optional[list[str]] = None + expected_user_ids: list[str], not_user_types: list[str] | None = None ) -> None: """Runs a test for the not_user_types param Args: @@ -1262,7 +1261,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) def test_user_type( - expected_user_ids: list[str], not_user_types: Optional[list[str]] = None + expected_user_ids: list[str], not_user_types: list[str] | None = None ) -> None: """Runs a test for the not_user_types param Args: @@ -1374,8 +1373,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): def _order_test( self, expected_user_list: list[str], - order_by: Optional[str], - dir: Optional[str] = None, + order_by: str | None, + dir: str | None = None, ) -> None: """Request the list of users in a certain order. Assert that order is what we expect @@ -3116,7 +3115,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) - def set_user_type(self, user_type: Optional[str]) -> None: + def set_user_type(self, user_type: str | None) -> None: # Set to user_type channel = self.make_request( "PUT", @@ -4213,8 +4212,8 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): def _order_test( self, expected_media_list: list[str], - order_by: Optional[str], - dir: Optional[str] = None, + order_by: str | None, + dir: str | None = None, ) -> None: """Request the list of media in a certain order. Assert that order is what we expect diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index b2c1d7ac0a..c3091ce412 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from twisted.internet.testing import MemoryReactor @@ -44,8 +43,8 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): async def check_username( localpart: str, - guest_access_token: Optional[str] = None, - assigned_user_id: Optional[str] = None, + guest_access_token: str | None = None, + assigned_user_id: str | None = None, inhibit_user_in_use_error: bool = False, ) -> None: if localpart == "allowed": diff --git a/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py b/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py index de76334f64..aa251bd78b 100644 --- a/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py +++ b/tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py @@ -13,7 +13,7 @@ # import logging from http import HTTPStatus -from typing import Optional, cast +from typing import cast from twisted.test.proto_helpers import MemoryReactor @@ -455,7 +455,7 @@ class SlidingSyncThreadSubscriptionsExtensionTestCase(SlidingSyncBase): def _do_backpaginate( self, *, from_tok: str, to_tok: str, limit: int, access_token: str - ) -> tuple[JsonDict, Optional[str]]: + ) -> tuple[JsonDict, str | None]: channel = self.make_request( "GET", "/_matrix/client/unstable/io.element.msc4308/thread_subscriptions" @@ -465,7 +465,7 @@ class SlidingSyncThreadSubscriptionsExtensionTestCase(SlidingSyncBase): self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) body = channel.json_body - return body, cast(Optional[str], body.get("end")) + return body, cast(str | None, body.get("end")) def _subscribe_to_thread( self, user_id: str, room_id: str, thread_root_id: str diff --git a/tests/rest/client/sliding_sync/test_rooms_timeline.py b/tests/rest/client/sliding_sync/test_rooms_timeline.py index 04a9cd5382..bc23776326 100644 --- a/tests/rest/client/sliding_sync/test_rooms_timeline.py +++ b/tests/rest/client/sliding_sync/test_rooms_timeline.py @@ -12,7 +12,6 @@ # . # import logging -from typing import Optional from parameterized import parameterized_class @@ -66,7 +65,7 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase): self, actual_items: StrSequence, expected_items: StrSequence, - message: Optional[str] = None, + message: str | None = None, ) -> None: """ Like `self.assertListEqual(...)` but with an actually understandable diff message. @@ -103,7 +102,7 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase): room_id: str, actual_event_ids: list[str], expected_event_ids: list[str], - message: Optional[str] = None, + message: str | None = None, ) -> None: """ Like `self.assertListEqual(...)` for event IDs in a room but will give a nicer diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py index 9f4c6bad05..c27a712088 100644 --- a/tests/rest/client/sliding_sync/test_sliding_sync.py +++ b/tests/rest/client/sliding_sync/test_sliding_sync.py @@ -12,7 +12,7 @@ # . # import logging -from typing import Any, Iterable, Literal, Optional +from typing import Any, Iterable, Literal from unittest.mock import AsyncMock from parameterized import parameterized, parameterized_class @@ -81,7 +81,7 @@ class SlidingSyncBase(unittest.HomeserverTestCase): return config def do_sync( - self, sync_body: JsonDict, *, since: Optional[str] = None, tok: str + self, sync_body: JsonDict, *, since: str | None = None, tok: str ) -> tuple[JsonDict, str]: """Do a sliding sync request with given body. @@ -239,8 +239,8 @@ class SlidingSyncBase(unittest.HomeserverTestCase): def _create_remote_invite_room_for_user( self, invitee_user_id: str, - unsigned_invite_room_state: Optional[list[StrippedStateEvent]], - invite_room_id: Optional[str] = None, + unsigned_invite_room_state: list[StrippedStateEvent] | None, + invite_room_id: str | None = None, ) -> str: """ Create a fake invite for a remote room and persist it. diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 03474d7400..ffa96c7840 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -23,7 +23,7 @@ import os import re from email.parser import Parser from http import HTTPStatus -from typing import Any, Optional, Union +from typing import Any from unittest.mock import Mock from twisted.internet.interfaces import IReactorTCP @@ -363,7 +363,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): email: str, client_secret: str, ip: str = "127.0.0.1", - next_link: Optional[str] = None, + next_link: str | None = None, ) -> str: body = {"client_secret": client_secret, "email": email, "send_attempt": 1} if next_link is not None: @@ -384,7 +384,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return channel.json_body["sid"] - def _validate_token(self, link: str, next_link: Optional[str] = None) -> None: + def _validate_token(self, link: str, next_link: str | None = None) -> None: # Remove the host path = link.replace("https://example.com", "") @@ -1152,9 +1152,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self, email: str, client_secret: str, - next_link: Optional[str] = None, + next_link: str | None = None, expect_code: int = HTTPStatus.OK, - ) -> Optional[str]: + ) -> str | None: """Request a validation token to add an email address to a user's account Args: @@ -1394,10 +1394,10 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): async def post_json( destination: str, path: str, - data: Optional[JsonDict] = None, + data: JsonDict | None = None, *a: Any, **kwa: Any, - ) -> Union[JsonDict, list]: + ) -> JsonDict | list: if destination == "remote": return { "account_statuses": { @@ -1503,11 +1503,11 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): def _test_status( self, - users: Optional[list[str]], + users: list[str] | None, expected_status_code: int = HTTPStatus.OK, - expected_statuses: Optional[dict[str, dict[str, bool]]] = None, - expected_failures: Optional[list[str]] = None, - expected_errcode: Optional[str] = None, + expected_statuses: dict[str, dict[str, bool]] | None = None, + expected_failures: list[str] | None = None, + expected_errcode: str | None = None, ) -> None: """Send a request to the account status endpoint and check that the response matches with what's expected. diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 5955d4b7a2..ffaf0e5a32 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -20,7 +20,7 @@ # import re from http import HTTPStatus -from typing import Any, Optional, Union +from typing import Any from twisted.internet.defer import succeed from twisted.internet.testing import MemoryReactor @@ -90,7 +90,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): self, session: str, expected_post_response: int, - post_session: Optional[str] = None, + post_session: str | None = None, ) -> None: """Get and respond to a fallback recaptcha. Returns the second request.""" if post_session is None: @@ -220,7 +220,7 @@ class UIAuthTests(unittest.HomeserverTestCase): access_token: str, device: str, expected_response: int, - body: Union[bytes, JsonDict] = b"", + body: bytes | JsonDict = b"", ) -> FakeChannel: """Delete an individual device.""" channel = self.make_request( diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 1ebd59b42a..d599351df7 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -26,8 +26,6 @@ from typing import ( Callable, Collection, Literal, - Optional, - Union, ) from unittest.mock import Mock from urllib.parse import urlencode @@ -141,14 +139,11 @@ class TestSpamChecker: async def check_login_for_spam( self, user_id: str, - device_id: Optional[str], - initial_display_name: Optional[str], - request_info: Collection[tuple[Optional[str], str]], - auth_provider_id: Optional[str] = None, - ) -> Union[ - Literal["NOT_SPAM"], - tuple["synapse.module_api.errors.Codes", JsonDict], - ]: + device_id: str | None, + initial_display_name: str | None, + request_info: Collection[tuple[str | None, str]], + auth_provider_id: str | None = None, + ) -> Literal["NOT_SPAM"] | tuple["synapse.module_api.errors.Codes", JsonDict]: return "NOT_SPAM" @@ -165,14 +160,11 @@ class DenyAllSpamChecker: async def check_login_for_spam( self, user_id: str, - device_id: Optional[str], - initial_display_name: Optional[str], - request_info: Collection[tuple[Optional[str], str]], - auth_provider_id: Optional[str] = None, - ) -> Union[ - Literal["NOT_SPAM"], - tuple["synapse.module_api.errors.Codes", JsonDict], - ]: + device_id: str | None, + initial_display_name: str | None, + request_info: Collection[tuple[str | None, str]], + auth_provider_id: str | None = None, + ) -> Literal["NOT_SPAM"] | tuple["synapse.module_api.errors.Codes", JsonDict]: # Return an odd set of values to ensure that they get correctly passed # to the client. return Codes.LIMIT_EXCEEDED, {"extra": "value"} @@ -984,7 +976,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) - def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: + def _make_sso_redirect_request(self, idp_prov: str | None = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect ... possibly specifying an IDP provider @@ -1888,8 +1880,8 @@ class UsernamePickerTestCase(HomeserverTestCase): async def mock_get_file( url: str, output_stream: BinaryIO, - max_size: Optional[int] = None, - headers: Optional[RawHeaders] = None, - is_allowed_content_type: Optional[Callable[[str], bool]] = None, + max_size: int | None = None, + headers: RawHeaders | None = None, + is_allowed_content_type: Callable[[str], bool] | None = None, ) -> tuple[int, dict[bytes, list[bytes]], str, int]: return 0, {b"Content-Type": [b"image/png"]}, "", 200 diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 79f70db8a3..33172f930e 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -24,7 +24,7 @@ import json import os import re import shutil -from typing import Any, BinaryIO, ClassVar, Optional, Sequence +from typing import Any, BinaryIO, ClassVar, Sequence from unittest.mock import MagicMock, Mock, patch from urllib import parse from urllib.parse import quote, urlencode @@ -273,7 +273,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): resolutionReceiver: IResolutionReceiver, hostName: str, portNumber: int = 0, - addressTypes: Optional[Sequence[type[IAddress]]] = None, + addressTypes: Sequence[type[IAddress]] | None = None, transportSemantics: str = "TCP", ) -> IResolutionReceiver: resolution = HostResolution(hostName) @@ -1661,7 +1661,7 @@ class MediaConfigModuleCallbackTestCase(unittest.HomeserverTestCase): async def get_media_config_for_user( self, user_id: str, - ) -> Optional[JsonDict]: + ) -> JsonDict | None: # We echo back the user_id and set a custom upload size. return {"m.upload.size": 1024, "user_id": user_id} @@ -1999,7 +1999,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): "Deferred[Any]", str, str, - Optional[QueryParams], + QueryParams | None, ] ] = [] @@ -2010,7 +2010,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): download_ratelimiter: Ratelimiter, ip_address: Any, max_size: int, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, ignore_backoff: bool = False, follow_redirects: bool = False, @@ -2044,7 +2044,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): download_ratelimiter: Ratelimiter, ip_address: Any, max_size: int, - args: Optional[QueryParams] = None, + args: QueryParams | None = None, retry_on_dns_fail: bool = True, ignore_backoff: bool = False, follow_redirects: bool = False, @@ -2107,7 +2107,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): self.tok = self.login("user", "pass") def _req( - self, content_disposition: Optional[bytes], include_content_type: bool = True + self, content_disposition: bytes | None, include_content_type: bool = True ) -> FakeChannel: channel = self.make_request( "GET", @@ -2418,7 +2418,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): def _test_thumbnail( self, method: str, - expected_body: Optional[bytes], + expected_body: bytes | None, expected_found: bool, unable_to_thumbnail: bool = False, ) -> None: @@ -3012,7 +3012,7 @@ class MediaUploadLimitsModuleOverrides(unittest.HomeserverTestCase): async def _get_media_upload_limits_for_user( self, user_id: str, - ) -> Optional[list[MediaUploadLimit]]: + ) -> list[MediaUploadLimit] | None: # user1 has custom limits if user_id == self.user1: # n.b. we return these in increasing duration order and Synapse will need to sort them correctly @@ -3037,7 +3037,7 @@ class MediaUploadLimitsModuleOverrides(unittest.HomeserverTestCase): sent_bytes: int, attempted_bytes: int, ) -> None: - self.last_media_upload_limit_exceeded: Optional[dict[str, object]] = { + self.last_media_upload_limit_exceeded: dict[str, object] | None = { "user_id": user_id, "limit": limit, "sent_bytes": sent_bytes, diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py index 7e2a63955c..17f8da3e8b 100644 --- a/tests/rest/client/test_notifications.py +++ b/tests/rest/client/test_notifications.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from unittest.mock import AsyncMock, Mock from twisted.internet.testing import MemoryReactor @@ -155,7 +154,7 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(notification_event_ids, sent_event_ids[2:]) def _request_notifications( - self, from_token: Optional[str], limit: int, expected_count: int + self, from_token: str | None, limit: int, expected_count: int ) -> tuple[list[str], str]: """ Make a request to /notifications to get the latest events to be notified about. diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index aa9b72c65e..023a376ed1 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -24,7 +24,7 @@ import logging import urllib.parse from http import HTTPStatus -from typing import Any, Optional +from typing import Any from canonicaljson import encode_canonical_json @@ -177,7 +177,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def _get_displayname(self, name: Optional[str] = None) -> Optional[str]: + def _get_displayname(self, name: str | None = None) -> str | None: channel = self.make_request( "GET", "/profile/%s/displayname" % (name or self.owner,) ) @@ -187,7 +187,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): # https://github.com/matrix-org/synapse/issues/13137. return channel.json_body.get("displayname") - def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]: + def _get_avatar_url(self, name: str | None = None) -> str | None: channel = self.make_request( "GET", "/profile/%s/avatar_url" % (name or self.owner,) ) @@ -846,7 +846,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): self.try_fetch_profile(200, self.requester_tok) def try_fetch_profile( - self, expected_code: int, access_token: Optional[str] = None + self, expected_code: int, access_token: str | None = None ) -> None: self.request_profile(expected_code, access_token=access_token) @@ -862,7 +862,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): self, expected_code: int, url_suffix: str = "", - access_token: Optional[str] = None, + access_token: str | None = None, ) -> None: channel = self.make_request( "GET", self.profile_url + url_suffix, access_token=access_token diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py index 0c1b631b8e..3a6a869c54 100644 --- a/tests/rest/client/test_receipts.py +++ b/tests/rest/client/test_receipts.py @@ -19,7 +19,6 @@ # # from http import HTTPStatus -from typing import Optional from twisted.internet.testing import MemoryReactor @@ -259,7 +258,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body) - def _get_read_receipt(self) -> Optional[JsonDict]: + def _get_read_receipt(self) -> JsonDict | None: """Syncs and returns the read receipt.""" # Checks if event is a read receipt diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 88be8748ee..997ca5f9ca 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from parameterized import parameterized @@ -85,8 +84,8 @@ class RedactionsTestCase(HomeserverTestCase): room_id: str, event_id: str, expect_code: int = 200, - with_relations: Optional[list[str]] = None, - content: Optional[JsonDict] = None, + with_relations: list[str] | None = None, + content: JsonDict | None = None, ) -> JsonDict: """Helper function to send a redaction event. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 3912a3c772..2d8ba77a77 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -20,7 +20,7 @@ # import urllib.parse -from typing import Any, Callable, Optional +from typing import Any, Callable from unittest.mock import AsyncMock, patch from twisted.internet.testing import MemoryReactor @@ -79,10 +79,10 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): self, relation_type: str, event_type: str, - key: Optional[str] = None, - content: Optional[dict] = None, - access_token: Optional[str] = None, - parent_id: Optional[str] = None, + key: str | None = None, + content: dict | None = None, + access_token: str | None = None, + parent_id: str | None = None, expected_response_code: int = 200, ) -> FakeChannel: """Helper function to send a relation pointing at `self.parent_id` @@ -845,7 +845,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): ) expected_event_ids.append(channel.json_body["event_id"]) - prev_token: Optional[str] = "" + prev_token: str | None = "" found_event_ids: list[str] = [] for _ in range(20): from_token = "" @@ -1085,7 +1085,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): relation_type: str, assertion_callable: Callable[[JsonDict], None], expected_db_txn_for_event: int, - access_token: Optional[str] = None, + access_token: str | None = None, ) -> None: """ Makes requests to various endpoints which should include bundled aggregations diff --git a/tests/rest/client/test_reporting.py b/tests/rest/client/test_reporting.py index 0fd02f65a6..96697b96d5 100644 --- a/tests/rest/client/test_reporting.py +++ b/tests/rest/client/test_reporting.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from twisted.internet.testing import MemoryReactor @@ -311,7 +310,7 @@ class ReportUserTestCase(unittest.HomeserverTestCase): self.assertEqual(len(rows), 0) def _assert_status( - self, response_status: int, data: JsonDict, user_id: Optional[str] = None + self, response_status: int, data: JsonDict, user_id: str | None = None ) -> None: if user_id is None: user_id = self.target_user_id diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4142aed363..68e09afc54 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -25,7 +25,7 @@ import json from http import HTTPStatus -from typing import Any, Iterable, Literal, Optional, Union +from typing import Any, Iterable, Literal from unittest.mock import AsyncMock, Mock, call, patch from urllib import parse as urlparse @@ -74,7 +74,7 @@ PATH_PREFIX = b"/_matrix/client/api/v1" class RoomBase(unittest.HomeserverTestCase): - rmcreator_id: Optional[str] = None + rmcreator_id: str | None = None servlets = [room.register_servlets, room.register_deprecated_servlets] @@ -959,7 +959,7 @@ class RoomsCreateTestCase(RoomBase): """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. - In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + In this test, we use the more recent API in which callbacks return a `Codes | Literal["NOT_SPAM"]`. """ async def user_may_join_room_codes( @@ -1351,7 +1351,7 @@ class RoomJoinTestCase(RoomBase): """ # Register a dummy callback. Make it allow all room joins for now. - return_value: Union[Literal["NOT_SPAM"], tuple[Codes, dict], Codes] = ( + return_value: Literal["NOT_SPAM"] | tuple[Codes, dict] | Codes = ( synapse.module_api.NOT_SPAM ) @@ -1359,7 +1359,7 @@ class RoomJoinTestCase(RoomBase): userid: str, room_id: str, is_invited: bool, - ) -> Union[Literal["NOT_SPAM"], tuple[Codes, dict], Codes]: + ) -> Literal["NOT_SPAM"] | tuple[Codes, dict] | Codes: return return_value # `spec` argument is needed for this function mock to have `__qualname__`, which @@ -1848,20 +1848,20 @@ class RoomMessagesTestCase(RoomBase): def test_spam_checker_check_event_for_spam( self, name: str, - value: Union[str, bool, Codes, tuple[Codes, JsonDict]], + value: str | bool | Codes | tuple[Codes, JsonDict], expected_code: int, expected_fields: dict, ) -> None: class SpamCheck: - mock_return_value: Union[str, bool, Codes, tuple[Codes, JsonDict], bool] = ( + mock_return_value: str | bool | Codes | tuple[Codes, JsonDict] | bool = ( "NOT_SPAM" ) - mock_content: Optional[JsonDict] = None + mock_content: JsonDict | None = None async def check_event_for_spam( self, event: synapse.events.EventBase, - ) -> Union[str, Codes, tuple[Codes, JsonDict], bool]: + ) -> str | Codes | tuple[Codes, JsonDict] | bool: self.mock_content = event.content return self.mock_return_value @@ -2707,8 +2707,8 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): def make_public_rooms_request( self, - room_types: Optional[list[Union[str, None]]], - instance_id: Optional[str] = None, + room_types: list[str | None] | None, + instance_id: str | None = None, ) -> tuple[list[dict[str, Any]], int]: body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}} if instance_id: @@ -3968,7 +3968,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): """ Test allowing/blocking threepid invites with a spam-check module. - In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + In this test, we use the more recent API in which callbacks return a `Codes | Literal["NOT_SPAM"]`. """ # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we @@ -4532,7 +4532,7 @@ class MSC4293RedactOnBanKickTestCase(unittest.FederatingHomeserverTestCase): original_events: list[EventBase], pulled_events: list[JsonDict], expect_redaction: bool, - reason: Optional[str] = None, + reason: str | None = None, ) -> None: """ Checks a set of original events against a second set of the same events, pulled diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 78fa8f4e1c..0d319dff7e 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -19,7 +19,7 @@ # # import threading -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock, Mock from twisted.internet.testing import MemoryReactor @@ -61,7 +61,7 @@ class LegacyThirdPartyRulesTestModule: async def check_event_allowed( self, event: EventBase, state: StateMap[EventBase] - ) -> Union[bool, dict]: + ) -> bool | dict: return True @staticmethod @@ -150,7 +150,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # types async def check( ev: EventBase, state: StateMap[EventBase] - ) -> tuple[bool, Optional[JsonDict]]: + ) -> tuple[bool, JsonDict | None]: return ev.type != "foo.bar.forbidden", None callback = Mock(spec=[], side_effect=check) @@ -195,7 +195,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """ class NastyHackException(SynapseError): - def error_dict(self, config: Optional[HomeServerConfig]) -> JsonDict: + def error_dict(self, config: HomeServerConfig | None) -> JsonDict: """ This overrides SynapseError's `error_dict` to nastily inject JSON into the error response. @@ -207,7 +207,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # add a callback that will raise our hacky exception async def check( ev: EventBase, state: StateMap[EventBase] - ) -> tuple[bool, Optional[JsonDict]]: + ) -> tuple[bool, JsonDict | None]: raise NastyHackException(429, "message") self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [ @@ -235,7 +235,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # first patch the event checker so that it will try to modify the event async def check( ev: EventBase, state: StateMap[EventBase] - ) -> tuple[bool, Optional[JsonDict]]: + ) -> tuple[bool, JsonDict | None]: ev.content = {"x": "y"} return True, None @@ -260,7 +260,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] - ) -> tuple[bool, Optional[JsonDict]]: + ) -> tuple[bool, JsonDict | None]: d = ev.get_dict() d["content"] = {"x": "y"} return True, d @@ -295,7 +295,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] - ) -> tuple[bool, Optional[JsonDict]]: + ) -> tuple[bool, JsonDict | None]: d = ev.get_dict() d["content"] = { "msgtype": "m.text", @@ -443,7 +443,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Define a callback that sends a custom event on power levels update. async def test_fn( event: EventBase, state_events: StateMap[EventBase] - ) -> tuple[bool, Optional[JsonDict]]: + ) -> tuple[bool, JsonDict | None]: if event.is_state() and event.type == EventTypes.PowerLevels: await api.create_and_send_event_into_room( { diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index da114e505d..ee26492909 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from unittest.mock import patch from twisted.internet.testing import MemoryReactor @@ -56,8 +55,8 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): def _upgrade_room( self, - token: Optional[str] = None, - room_id: Optional[str] = None, + token: str | None = None, + room_id: str | None = None, expire_cache: bool = True, ) -> FakeChannel: if expire_cache: diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index d5c824b291..613c317b8a 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -34,7 +34,6 @@ from typing import ( Literal, Mapping, MutableMapping, - Optional, Sequence, overload, ) @@ -75,42 +74,42 @@ class RestHelper: hs: HomeServer reactor: MemoryReactorClock site: Site - auth_user_id: Optional[str] + auth_user_id: str | None @overload def create_room_as( self, - room_creator: Optional[str] = ..., - is_public: Optional[bool] = ..., - room_version: Optional[str] = ..., - tok: Optional[str] = ..., + room_creator: str | None = ..., + is_public: bool | None = ..., + room_version: str | None = ..., + tok: str | None = ..., expect_code: Literal[200] = ..., - extra_content: Optional[dict] = ..., - custom_headers: Optional[Iterable[tuple[AnyStr, AnyStr]]] = ..., + extra_content: dict | None = ..., + custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = ..., ) -> str: ... @overload def create_room_as( self, - room_creator: Optional[str] = ..., - is_public: Optional[bool] = ..., - room_version: Optional[str] = ..., - tok: Optional[str] = ..., + room_creator: str | None = ..., + is_public: bool | None = ..., + room_version: str | None = ..., + tok: str | None = ..., expect_code: int = ..., - extra_content: Optional[dict] = ..., - custom_headers: Optional[Iterable[tuple[AnyStr, AnyStr]]] = ..., - ) -> Optional[str]: ... + extra_content: dict | None = ..., + custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = ..., + ) -> str | None: ... def create_room_as( self, - room_creator: Optional[str] = None, - is_public: Optional[bool] = True, - room_version: Optional[str] = None, - tok: Optional[str] = None, + room_creator: str | None = None, + is_public: bool | None = True, + room_version: str | None = None, + tok: str | None = None, expect_code: int = HTTPStatus.OK, - extra_content: Optional[dict] = None, - custom_headers: Optional[Iterable[tuple[AnyStr, AnyStr]]] = None, - ) -> Optional[str]: + extra_content: dict | None = None, + custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = None, + ) -> str | None: """ Create a room. @@ -166,11 +165,11 @@ class RestHelper: def invite( self, room: str, - src: Optional[str] = None, - targ: Optional[str] = None, + src: str | None = None, + targ: str | None = None, expect_code: int = HTTPStatus.OK, - tok: Optional[str] = None, - extra_data: Optional[dict] = None, + tok: str | None = None, + extra_data: dict | None = None, ) -> JsonDict: return self.change_membership( room=room, @@ -187,10 +186,10 @@ class RestHelper: room: str, user: str, expect_code: int = HTTPStatus.OK, - tok: Optional[str] = None, - appservice_user_id: Optional[str] = None, - expect_errcode: Optional[Codes] = None, - expect_additional_fields: Optional[dict] = None, + tok: str | None = None, + appservice_user_id: str | None = None, + expect_errcode: Codes | None = None, + expect_additional_fields: dict | None = None, ) -> JsonDict: return self.change_membership( room=room, @@ -206,11 +205,11 @@ class RestHelper: def knock( self, - room: Optional[str] = None, - user: Optional[str] = None, - reason: Optional[str] = None, + room: str | None = None, + user: str | None = None, + reason: str | None = None, expect_code: int = HTTPStatus.OK, - tok: Optional[str] = None, + tok: str | None = None, ) -> None: temp_id = self.auth_user_id self.auth_user_id = user @@ -241,9 +240,9 @@ class RestHelper: def leave( self, room: str, - user: Optional[str] = None, + user: str | None = None, expect_code: int = HTTPStatus.OK, - tok: Optional[str] = None, + tok: str | None = None, ) -> JsonDict: return self.change_membership( room=room, @@ -260,7 +259,7 @@ class RestHelper: src: str, targ: str, expect_code: int = HTTPStatus.OK, - tok: Optional[str] = None, + tok: str | None = None, ) -> JsonDict: """A convenience helper: `change_membership` with `membership` preset to "ban".""" return self.change_membership( @@ -275,15 +274,15 @@ class RestHelper: def change_membership( self, room: str, - src: Optional[str], - targ: Optional[str], + src: str | None, + targ: str | None, membership: str, - extra_data: Optional[dict] = None, - tok: Optional[str] = None, - appservice_user_id: Optional[str] = None, + extra_data: dict | None = None, + tok: str | None = None, + appservice_user_id: str | None = None, expect_code: int = HTTPStatus.OK, - expect_errcode: Optional[str] = None, - expect_additional_fields: Optional[dict] = None, + expect_errcode: str | None = None, + expect_additional_fields: dict | None = None, ) -> JsonDict: """ Send a membership state event into a room. @@ -372,11 +371,11 @@ class RestHelper: def send( self, room_id: str, - body: Optional[str] = None, - txn_id: Optional[str] = None, - tok: Optional[str] = None, + body: str | None = None, + txn_id: str | None = None, + tok: str | None = None, expect_code: int = HTTPStatus.OK, - custom_headers: Optional[Iterable[tuple[AnyStr, AnyStr]]] = None, + custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = None, type: str = "m.room.message", ) -> JsonDict: if body is None: @@ -402,7 +401,7 @@ class RestHelper: "msgtype": "m.text", "body": f"Test event {idx}", }, - tok: Optional[str] = None, + tok: str | None = None, ) -> Sequence[str]: """ Helper to send a handful of sequential events and return their event IDs as a sequence. @@ -424,11 +423,11 @@ class RestHelper: self, room_id: str, type: str, - content: Optional[dict] = None, - txn_id: Optional[str] = None, - tok: Optional[str] = None, + content: dict | None = None, + txn_id: str | None = None, + tok: str | None = None, expect_code: int = HTTPStatus.OK, - custom_headers: Optional[Iterable[tuple[AnyStr, AnyStr]]] = None, + custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = None, ) -> JsonDict: if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -458,7 +457,7 @@ class RestHelper: self, room_id: str, event_id: str, - tok: Optional[str] = None, + tok: str | None = None, expect_code: int = HTTPStatus.OK, ) -> JsonDict: """Request a specific event from the server. @@ -495,8 +494,8 @@ class RestHelper: self, room_id: str, event_type: str, - body: Optional[dict[str, Any]], - tok: Optional[str], + body: dict[str, Any] | None, + tok: str | None, expect_code: int = HTTPStatus.OK, state_key: str = "", method: str = "GET", @@ -574,7 +573,7 @@ class RestHelper: room_id: str, event_type: str, body: dict[str, Any], - tok: Optional[str] = None, + tok: str | None = None, expect_code: int = HTTPStatus.OK, state_key: str = "", ) -> JsonDict: @@ -680,7 +679,7 @@ class RestHelper: fake_server: FakeOidcServer, remote_user_id: str, with_sid: bool = False, - idp_id: Optional[str] = None, + idp_id: str | None = None, expected_status: int = 200, ) -> tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC @@ -751,10 +750,10 @@ class RestHelper: self, fake_server: FakeOidcServer, user_info_dict: JsonDict, - client_redirect_url: Optional[str] = None, - ui_auth_session_id: Optional[str] = None, + client_redirect_url: str | None = None, + ui_auth_session_id: str | None = None, with_sid: bool = False, - idp_id: Optional[str] = None, + idp_id: str | None = None, ) -> tuple[FakeChannel, FakeAuthorizationGrant]: """Perform an OIDC authentication flow via a mock OIDC provider. @@ -878,9 +877,9 @@ class RestHelper: def initiate_sso_login( self, - client_redirect_url: Optional[str], + client_redirect_url: str | None, cookies: MutableMapping[str, str], - idp_id: Optional[str] = None, + idp_id: str | None = None, ) -> str: """Make a request to the login-via-sso redirect endpoint, and return the target diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index c412a19f9b..aaf39e70e4 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -19,7 +19,7 @@ # # from io import BytesIO, StringIO -from typing import Any, Optional, Union +from typing import Any from unittest.mock import Mock import signedjson.key @@ -67,7 +67,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): path: str, ignore_backoff: bool = False, **kwargs: Any, - ) -> Union[JsonDict, list]: + ) -> JsonDict | list: self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -191,8 +191,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 async def post_json( - destination: str, path: str, data: Optional[JsonDict] = None - ) -> Union[JsonDict, list]: + destination: str, path: str, data: JsonDict | None = None + ) -> JsonDict | list: self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py index 5af2e79f45..32e78fc12a 100644 --- a/tests/rest/media/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py @@ -22,7 +22,7 @@ import base64 import json import os import re -from typing import Any, Optional, Sequence +from typing import Any, Sequence from urllib.parse import quote, urlencode from twisted.internet._resolver import HostResolution @@ -135,7 +135,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): resolutionReceiver: IResolutionReceiver, hostName: str, portNumber: int = 0, - addressTypes: Optional[Sequence[type[IAddress]]] = None, + addressTypes: Sequence[type[IAddress]] | None = None, transportSemantics: str = "TCP", ) -> IResolutionReceiver: resolution = HostResolution(hostName) diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py index 2e71e2a797..0e697427bb 100644 --- a/tests/scripts/test_new_matrix_user.py +++ b/tests/scripts/test_new_matrix_user.py @@ -18,7 +18,6 @@ # # -from typing import Optional from unittest.mock import Mock, patch from synapse._scripts.register_new_matrix_user import request_registration @@ -34,14 +33,14 @@ class RegisterTestCase(TestCase): post that MAC. """ - def get(url: str, verify: Optional[bool] = None) -> Mock: + def get(url: str, verify: bool | None = None) -> Mock: r = Mock() r.status_code = 200 r.json = lambda: {"nonce": "a"} return r def post( - url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None + url: str, json: JsonDict | None = None, verify: bool | None = None ) -> Mock: # Make sure we are sent the correct info assert json is not None @@ -85,7 +84,7 @@ class RegisterTestCase(TestCase): If the script fails to fetch a nonce, it throws an error and quits. """ - def get(url: str, verify: Optional[bool] = None) -> Mock: + def get(url: str, verify: bool | None = None) -> Mock: r = Mock() r.status_code = 404 r.reason = "Not Found" @@ -123,14 +122,14 @@ class RegisterTestCase(TestCase): report an error and quit. """ - def get(url: str, verify: Optional[bool] = None) -> Mock: + def get(url: str, verify: bool | None = None) -> Mock: r = Mock() r.status_code = 200 r.json = lambda: {"nonce": "a"} return r def post( - url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None + url: str, json: JsonDict | None = None, verify: bool | None = None ) -> Mock: # Make sure we are sent the correct info assert json is not None diff --git a/tests/server.py b/tests/server.py index ff5c606180..30337f3e38 100644 --- a/tests/server.py +++ b/tests/server.py @@ -119,11 +119,11 @@ R = TypeVar("R") P = ParamSpec("P") # the type of thing that can be passed into `make_request` in the headers list -CustomHeaderType = tuple[Union[str, bytes], Union[str, bytes]] +CustomHeaderType = tuple[str | bytes, str | bytes] # A pre-prepared SQLite DB that is used as a template when creating new SQLite # DB each test run. This dramatically speeds up test set up when using SQLite. -PREPPED_SQLITE_DB_CONN: Optional[LoggingDatabaseConnection] = None +PREPPED_SQLITE_DB_CONN: LoggingDatabaseConnection | None = None class TimedOutException(Exception): @@ -146,9 +146,9 @@ class FakeChannel: _reactor: MemoryReactorClock result: dict = attr.Factory(dict) _ip: str = "127.0.0.1" - _producer: Optional[Union[IPullProducer, IPushProducer]] = None - resource_usage: Optional[ContextResourceUsage] = None - _request: Optional[Request] = None + _producer: IPullProducer | IPushProducer | None = None + resource_usage: ContextResourceUsage | None = None + _request: Request | None = None @property def request(self) -> Request: @@ -206,7 +206,7 @@ class FakeChannel: version: bytes, code: bytes, reason: bytes, - headers: Union[Headers, list[tuple[bytes, bytes]]], + headers: Headers | list[tuple[bytes, bytes]], ) -> None: self.result["version"] = version self.result["code"] = code @@ -248,7 +248,7 @@ class FakeChannel: # TODO This should ensure that the IProducer is an IPushProducer or # IPullProducer, unfortunately twisted.protocols.basic.FileSender does # implement those, but doesn't declare it. - self._producer = cast(Union[IPushProducer, IPullProducer], producer) + self._producer = cast(IPushProducer | IPullProducer, producer) self.producerStreaming = streaming def _produce() -> None: @@ -357,18 +357,18 @@ class FakeSite: def make_request( reactor: MemoryReactorClock, - site: Union[Site, FakeSite], - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, str, JsonDict] = b"", - access_token: Optional[str] = None, + site: Site | FakeSite, + method: bytes | str, + path: bytes | str, + content: bytes | str | JsonDict = b"", + access_token: str | None = None, request: type[Request] = SynapseRequest, shorthand: bool = True, - federation_auth_origin: Optional[bytes] = None, - content_type: Optional[bytes] = None, + federation_auth_origin: bytes | None = None, + content_type: bytes | None = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[Iterable[CustomHeaderType]] = None, + custom_headers: Iterable[CustomHeaderType] | None = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ @@ -497,7 +497,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): @implementer(IResolverSimple) class FakeResolver: def getHostByName( - self, name: str, timeout: Optional[Sequence[int]] = None + self, name: str, timeout: Sequence[int] | None = None ) -> "Deferred[str]": if name not in lookups: return fail(DNSLookupError("OH NO: unknown %s" % (name,))) @@ -617,7 +617,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): port: int, factory: ClientFactory, timeout: float = 30, - bindAddress: Optional[tuple[str, int]] = None, + bindAddress: tuple[str, int] | None = None, ) -> IConnector: """Fake L{IReactorTCP.connectTCP}.""" @@ -788,7 +788,7 @@ class ThreadPool: def callInThreadWithCallback( self, - onResult: Callable[[bool, Union[Failure, R]], None], + onResult: Callable[[bool, Failure | R], None], function: Callable[P, R], *args: P.args, **kwargs: P.kwargs, @@ -841,17 +841,17 @@ class FakeTransport: """Test reactor """ - _protocol: Optional[IProtocol] = None + _protocol: IProtocol | None = None """The Protocol which is producing data for this transport. Optional, but if set will get called back for connectionLost() notifications etc. """ - _peer_address: Union[IPv4Address, IPv6Address] = attr.Factory( + _peer_address: IPv4Address | IPv6Address = attr.Factory( lambda: address.IPv4Address("TCP", "127.0.0.1", 5678) ) """The value to be returned by getPeer""" - _host_address: Union[IPv4Address, IPv6Address] = attr.Factory( + _host_address: IPv4Address | IPv6Address = attr.Factory( lambda: address.IPv4Address("TCP", "127.0.0.1", 1234) ) """The value to be returned by getHost""" @@ -860,13 +860,13 @@ class FakeTransport: disconnected = False connected = True buffer: bytes = b"" - producer: Optional[IPushProducer] = None + producer: IPushProducer | None = None autoflush: bool = True - def getPeer(self) -> Union[IPv4Address, IPv6Address]: + def getPeer(self) -> IPv4Address | IPv6Address: return self._peer_address - def getHost(self) -> Union[IPv4Address, IPv6Address]: + def getHost(self) -> IPv4Address | IPv6Address: return self._host_address def loseConnection(self) -> None: @@ -955,7 +955,7 @@ class FakeTransport: for x in seq: self.write(x) - def flush(self, maxbytes: Optional[int] = None) -> None: + def flush(self, maxbytes: int | None = None) -> None: if not self.buffer: # nothing to do. Don't write empty buffers: it upsets the # TLSMemoryBIOProtocol @@ -1061,10 +1061,10 @@ def setup_test_homeserver( *, cleanup_func: Callable[[Callable[[], Optional["Deferred[None]"]]], None], server_name: str = "test", - config: Optional[HomeServerConfig] = None, - reactor: Optional[ISynapseReactor] = None, + config: HomeServerConfig | None = None, + reactor: ISynapseReactor | None = None, homeserver_to_use: type[HomeServer] = TestHomeServer, - db_txn_limit: Optional[int] = None, + db_txn_limit: int | None = None, **extra_homeserver_attributes: Any, ) -> HomeServer: """ diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 2cf411e30b..7db710846d 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -23,7 +23,6 @@ from typing import ( Collection, Iterable, Mapping, - Optional, TypeVar, ) @@ -79,7 +78,7 @@ class FakeEvent: id: str, sender: str, type: str, - state_key: Optional[str], + state_key: str | None, content: Mapping[str, object], ): self.node_id = id @@ -525,7 +524,7 @@ class StateTestCase(unittest.TestCase): # EventBuilder. But this is Hard because the relevant attributes are # DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent. # 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and - # change this function to accept Union[Event, EventBase, EventBuilder]. + # change this function to accept Event | EventBase | EventBuilder. # This seems reasonable to me, but mypy isn't happy. I think that's # a mypy bug, see https://github.com/python/mypy/issues/5570 # Instead, resort to a type-ignore. @@ -1082,8 +1081,8 @@ class TestStateResolutionStore: self, room_id: str, auth_sets: list[set[str]], - conflicted_state: Optional[set[str]], - additional_backwards_reachable_conflicted_events: Optional[set[str]], + conflicted_state: set[str] | None, + additional_backwards_reachable_conflicted_events: set[str] | None, ) -> "defer.Deferred[StateDifference]": chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] diff --git a/tests/state/test_v21.py b/tests/state/test_v21.py index 7bef3decf0..b17773fb56 100644 --- a/tests/state/test_v21.py +++ b/tests/state/test_v21.py @@ -18,7 +18,7 @@ # # import itertools -from typing import Optional, Sequence +from typing import Sequence from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -357,11 +357,11 @@ class StateResV21TestCase(unittest.HomeserverTestCase): self, room_id: str, state_maps: Sequence[StateMap[str]], - event_map: Optional[dict[str, EventBase]], + event_map: dict[str, EventBase] | None, state_res_store: StateResolutionStoreInterface, ) -> set[str]: _, conflicted_state = _seperate(state_maps) - conflicted_set: Optional[set[str]] = set( + conflicted_set: set[str] | None = set( itertools.chain.from_iterable(conflicted_state.values()) ) if event_map is None: @@ -458,7 +458,7 @@ class StateResV21TestCase(unittest.HomeserverTestCase): resolve_and_check() def persist_event( - self, event: EventBase, state: Optional[StateMap[str]] = None + self, event: EventBase, state: StateMap[str] | None = None ) -> None: """Persist the event, with optional state""" context = self.get_success( @@ -473,12 +473,12 @@ class StateResV21TestCase(unittest.HomeserverTestCase): def create_event( self, event_type: str, - state_key: Optional[str], + state_key: str | None, sender: str, content: dict, auth_events: list[str], - prev_events: Optional[list[str]] = None, - room_id: Optional[str] = None, + prev_events: list[str] | None = None, + room_id: str | None = None, ) -> EventBase: """Short-hand for event_from_pdu_json for fields we typically care about. Tests can override by just calling event_from_pdu_json directly.""" diff --git a/tests/storage/databases/main/test_end_to_end_keys.py b/tests/storage/databases/main/test_end_to_end_keys.py index 35e1e15d66..d21dac6024 100644 --- a/tests/storage/databases/main/test_end_to_end_keys.py +++ b/tests/storage/databases/main/test_end_to_end_keys.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from twisted.internet.testing import MemoryReactor @@ -99,7 +98,7 @@ class EndToEndKeyWorkerStoreTestCase(HomeserverTestCase): def check_timestamp_column( txn: LoggingTransaction, - ) -> list[tuple[JsonDict, Optional[int]]]: + ) -> list[tuple[JsonDict, int | None]]: """Fetch all rows for Alice's keys.""" txn.execute( """ diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index 2d63b52aca..be29e0a7f4 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -19,7 +19,7 @@ # # -from typing import Any, Optional, Sequence +from typing import Any, Sequence from twisted.internet.testing import MemoryReactor @@ -52,7 +52,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): index_name: str, table: str, receipts: dict[tuple[str, str, str], Sequence[dict[str, Any]]], - expected_unique_receipts: dict[tuple[str, str, str], Optional[dict[str, Any]]], + expected_unique_receipts: dict[tuple[str, str, str], dict[str, Any] | None], ) -> None: """Test that the background update to uniqueify non-thread receipts in the given receipts table works properly. diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index d9307154da..c91aad097d 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -19,7 +19,7 @@ # # -from typing import Iterable, Optional +from typing import Iterable from twisted.internet.testing import MemoryReactor @@ -37,7 +37,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): self.user = "@user:test" def _update_ignore_list( - self, *ignored_user_ids: Iterable[str], ignorer_user_id: Optional[str] = None + self, *ignored_user_ids: Iterable[str], ignorer_user_id: str | None = None ) -> None: """Update the account data to block the given users.""" if ignorer_user_id is None: @@ -167,7 +167,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): """Test that ignoring users updates the latest stream ID for the ignored user list account data.""" - def get_latest_ignore_streampos(user_id: str) -> Optional[int]: + def get_latest_ignore_streampos(user_id: str) -> int | None: return self.get_success( self.store.get_latest_stream_id_for_global_account_data_by_type_for_user( user_id, AccountDataTypes.IGNORED_USER_LIST diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 2c1ba9d6c2..bd68f2aaa1 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -19,7 +19,7 @@ # # -from typing import Any, Optional, cast +from typing import Any, cast from unittest.mock import AsyncMock from parameterized import parameterized @@ -104,7 +104,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = cast( - list[tuple[str, str, str, Optional[str], int]], + list[tuple[str, str, str, str | None, int]], self.get_success( self.store.db_pool.simple_select_list( table="user_ips", @@ -135,7 +135,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = cast( - list[tuple[str, str, str, Optional[str], int]], + list[tuple[str, str, str, str | None, int]], self.get_success( self.store.db_pool.simple_select_list( table="user_ips", @@ -184,7 +184,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): else: # Check that the new IP and user agent has not been stored yet db_result = cast( - list[tuple[str, Optional[str], Optional[str], str, Optional[int]]], + list[tuple[str, str | None, str | None, str, int | None]], self.get_success( self.store.db_pool.simple_select_list( table="devices", @@ -266,7 +266,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Check that the new IP and user agent has not been stored yet db_result = cast( - list[tuple[str, Optional[str], Optional[str], str, Optional[int]]], + list[tuple[str, str | None, str | None, str, int | None]], self.get_success( self.store.db_pool.simple_select_list( table="devices", @@ -589,7 +589,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = cast( - list[tuple[str, str, str, Optional[str], int]], + list[tuple[str, str, str, str | None, int]], self.get_success( self.store.db_pool.simple_select_list( table="user_ips", @@ -616,7 +616,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = cast( - list[tuple[str, str, str, Optional[str], int]], + list[tuple[str, str, str, str | None, int]], self.get_success( self.store.db_pool.simple_select_list( table="user_ips", @@ -695,7 +695,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = cast( - list[tuple[str, str, str, Optional[str], int]], + list[tuple[str, str, str, str | None, int]], self.get_success( self.store.db_pool.simple_select_list( table="user_ips", diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index d8c6a1cd04..508f82de4f 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -25,7 +25,6 @@ from typing import ( Mapping, NamedTuple, TypeVar, - Union, cast, ) @@ -931,7 +930,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): room_id = "some_room_id" - def prev_event_format(prev_event_id: str) -> Union[tuple[str, dict], str]: + def prev_event_format(prev_event_id: str) -> tuple[str, dict] | str: """Account for differences in prev_events format across room versions""" if room_version.event_format == EventFormatVersions.ROOM_V1_V2: return prev_event_id, {} diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index ef6c0f2465..d5ed947094 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -19,7 +19,6 @@ # # -from typing import Optional from twisted.internet.testing import MemoryReactor @@ -345,9 +344,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): aggregate_counts[room_id], notif_count + thread_notif_count ) - def _create_event( - highlight: bool = False, thread_id: Optional[str] = None - ) -> str: + def _create_event(highlight: bool = False, thread_id: str | None = None) -> str: content: JsonDict = { "msgtype": "m.text", "body": user_id if highlight else "msg", @@ -527,9 +524,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): aggregate_counts[room_id], notif_count + thread_notif_count ) - def _create_event( - highlight: bool = False, thread_id: Optional[str] = None - ) -> str: + def _create_event(highlight: bool = False, thread_id: str | None = None) -> str: content: JsonDict = { "msgtype": "m.text", "body": user_id if highlight else "msg", @@ -553,7 +548,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): def _rotate() -> None: self.get_success(self.store._rotate_notifs()) - def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + def _mark_read(event_id: str, thread_id: str | None = None) -> None: self.get_success( self.store.insert_receipt( room_id, diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 5c7f814078..7d1c96f97f 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -20,7 +20,6 @@ # import logging -from typing import Optional from twisted.internet.testing import MemoryReactor @@ -168,7 +167,7 @@ class ExtremPruneTestCase(HomeserverTestCase): self.assert_extremities([self.remote_event_1.event_id]) def persist_event( - self, event: EventBase, state: Optional[StateMap[str]] = None + self, event: EventBase, state: StateMap[str] | None = None ) -> None: """Persist the event, with optional state""" context = self.get_success( diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 4846e8cac3..051c5de44d 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from twisted.internet.testing import MemoryReactor @@ -76,7 +75,7 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase): def _create_id_generator( self, instance_name: str = "master", - writers: Optional[list[str]] = None, + writers: list[str] | None = None, ) -> MultiWriterIdGenerator: def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( @@ -113,7 +112,7 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase): self._replicate(instance_name) def _insert_row( - self, instance_name: str, stream_id: int, table: Optional[str] = None + self, instance_name: str, stream_id: int, table: str | None = None ) -> None: """Insert one row as the given instance with given stream_id.""" @@ -144,7 +143,7 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase): self, instance_name: str, number: int, - table: Optional[str] = None, + table: str | None = None, update_stream_table: bool = True, ) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 9ea2fa5311..4f97d89f78 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -101,7 +101,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Test each of the registered users is marked as active timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) - # Mypy notes that one shouldn't compare Optional[int] to 0 with assertGreater. + # Mypy notes that one shouldn't compare int | None to 0 with assertGreater. # Check that timestamp really is an int. assert timestamp is not None self.assertGreater(timestamp, 0) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 10ded391f4..27875dcebb 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -19,7 +19,7 @@ # # -from typing import Collection, Optional +from typing import Collection from twisted.internet.testing import MemoryReactor @@ -101,8 +101,8 @@ class ReceiptTestCase(HomeserverTestCase): ) def get_last_unthreaded_receipt( - self, receipt_types: Collection[str], room_id: Optional[str] = None - ) -> Optional[str]: + self, receipt_types: Collection[str], room_id: str | None = None + ) -> str | None: """ Fetch the event ID for the latest unthreaded receipt in the test room for the test user. diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 2c188b8046..92eb99f1d5 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional, cast +from typing import cast from canonicaljson import json @@ -67,7 +67,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): room: RoomID, user: UserID, membership: str, - extra_content: Optional[JsonDict] = None, + extra_content: JsonDict | None = None, ) -> EventBase: content = {"membership": membership} content.update(extra_content or {}) @@ -248,8 +248,8 @@ class RedactionTestCase(unittest.HomeserverTestCase): async def build( self, prev_event_ids: list[str], - auth_event_ids: Optional[list[str]], - depth: Optional[int] = None, + auth_event_ids: list[str] | None, + depth: int | None = None, ) -> EventBase: built_event = await self._base_builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index c5487d81e6..f8d64e8ce6 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -20,7 +20,7 @@ # # import logging -from typing import Optional, cast +from typing import cast from twisted.internet.testing import MemoryReactor @@ -133,7 +133,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) res = cast( - list[tuple[Optional[str], str]], + list[tuple[str | None, str]], self.get_success( self.store.db_pool.simple_select_list( "room_memberships", @@ -165,7 +165,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) res2 = cast( - list[tuple[Optional[str], str]], + list[tuple[str | None, str]], self.get_success( self.store.db_pool.simple_select_list( "room_memberships", @@ -410,7 +410,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): actual_member_summary: MemberSummary, expected_member_list: list[str], *, - expected_member_count: Optional[int] = None, + expected_member_count: int | None = None, ) -> None: """ Assert that the `MemberSummary` object has the expected members. diff --git a/tests/storage/test_sliding_sync_tables.py b/tests/storage/test_sliding_sync_tables.py index 5cfc1a9c29..db31348a8c 100644 --- a/tests/storage/test_sliding_sync_tables.py +++ b/tests/storage/test_sliding_sync_tables.py @@ -18,7 +18,7 @@ # # import logging -from typing import Optional, cast +from typing import cast import attr from parameterized import parameterized @@ -55,12 +55,12 @@ class _SlidingSyncJoinedRoomResult: # `event.internal_metadata.stream_ordering` is marked optional because it only # exists for persisted events but in the context of these tests, we're only working # with persisted events and we're making comparisons so we will find any mismatch. - event_stream_ordering: Optional[int] - bump_stamp: Optional[int] - room_type: Optional[str] - room_name: Optional[str] + event_stream_ordering: int | None + bump_stamp: int | None + room_type: str | None + room_name: str | None is_encrypted: bool - tombstone_successor_room_id: Optional[str] + tombstone_successor_room_id: str | None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -75,12 +75,12 @@ class _SlidingSyncMembershipSnapshotResult: # `event.internal_metadata.stream_ordering` is marked optional because it only # exists for persisted events but in the context of these tests, we're only working # with persisted events and we're making comparisons so we will find any mismatch. - event_stream_ordering: Optional[int] + event_stream_ordering: int | None has_known_state: bool - room_type: Optional[str] - room_name: Optional[str] + room_type: str | None + room_name: str | None is_encrypted: bool - tombstone_successor_room_id: Optional[str] + tombstone_successor_room_id: str | None # Make this default to "not forgotten" because it doesn't apply to many tests and we # don't want to force all of the tests to deal with it. forgotten: bool = False @@ -207,7 +207,7 @@ class SlidingSyncTablesTestCaseBase(HomeserverTestCase): def _create_remote_invite_room_for_user( self, invitee_user_id: str, - unsigned_invite_room_state: Optional[list[StrippedStateEvent]], + unsigned_invite_room_state: list[StrippedStateEvent] | None, ) -> tuple[str, EventBase]: """ Create a fake invite for a remote room and persist it. @@ -2246,7 +2246,7 @@ class SlidingSyncTablesTestCase(SlidingSyncTablesTestCaseBase): ] ) def test_non_join_remote_invite_no_stripped_state( - self, _description: str, stripped_state: Optional[list[StrippedStateEvent]] + self, _description: str, stripped_state: list[StrippedStateEvent] | None ) -> None: """ Test remote invite with no stripped state provided shows up in diff --git a/tests/storage/test_thread_subscriptions.py b/tests/storage/test_thread_subscriptions.py index 3f78308e45..ec6f8c5bfb 100644 --- a/tests/storage/test_thread_subscriptions.py +++ b/tests/storage/test_thread_subscriptions.py @@ -12,7 +12,6 @@ # . # -from typing import Optional, Union from twisted.internet.testing import MemoryReactor @@ -102,10 +101,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): self, thread_root_id: str, *, - automatic_event_orderings: Optional[EventOrderings], - room_id: Optional[str] = None, - user_id: Optional[str] = None, - ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: + automatic_event_orderings: EventOrderings | None, + room_id: str | None = None, + user_id: str | None = None, + ) -> int | AutomaticSubscriptionConflicted | None: if user_id is None: user_id = self.user_id @@ -124,9 +123,9 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): def _unsubscribe( self, thread_root_id: str, - room_id: Optional[str] = None, - user_id: Optional[str] = None, - ) -> Optional[int]: + room_id: str | None = None, + user_id: str | None = None, + ) -> int | None: if user_id is None: user_id = self.user_id diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 83d3357c65..7b4acd985c 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -19,7 +19,7 @@ # # import re -from typing import Any, Optional, cast +from typing import Any, cast from unittest import mock from unittest.mock import Mock, patch @@ -110,7 +110,7 @@ class GetUserDirectoryTables: thing missing is an unused room_id column. """ rows = cast( - list[tuple[str, Optional[str], Optional[str]]], + list[tuple[str, str | None, str | None]], await self.store.db_pool.simple_select_list( "user_directory", None, diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 7737101967..934a2fd307 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -20,7 +20,7 @@ # import unittest -from typing import Any, Collection, Iterable, Optional +from typing import Any, Collection, Iterable from parameterized import parameterized @@ -797,8 +797,8 @@ def _member_event( room_version: RoomVersion, user_id: str, membership: str, - sender: Optional[str] = None, - additional_content: Optional[dict] = None, + sender: str | None = None, + additional_content: dict | None = None, ) -> EventBase: return make_event_from_dict( { @@ -818,7 +818,7 @@ def _member_event( def _join_event( room_version: RoomVersion, user_id: str, - additional_content: Optional[dict] = None, + additional_content: dict | None = None, ) -> EventBase: return _member_event( room_version, @@ -871,7 +871,7 @@ def _build_auth_dict_for_room_version( def _random_state_event( room_version: RoomVersion, sender: str, - auth_events: Optional[Iterable[EventBase]] = None, + auth_events: Iterable[EventBase] | None = None, ) -> EventBase: if auth_events is None: auth_events = [] diff --git a/tests/test_mau.py b/tests/test_mau.py index e535e7dc2e..2d5c4c5d1c 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -20,8 +20,6 @@ """Tests REST events for /rooms paths.""" -from typing import Optional - from twisted.internet.testing import MemoryReactor from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType @@ -313,7 +311,7 @@ class TestMauLimit(unittest.HomeserverTestCase): ) def create_user( - self, localpart: str, token: Optional[str] = None, appservice: bool = False + self, localpart: str, token: str | None = None, appservice: bool = False ) -> str: request_data = { "username": localpart, diff --git a/tests/test_server.py b/tests/test_server.py index e7d3febe3f..2df6bdfa44 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -20,7 +20,7 @@ import re from http import HTTPStatus -from typing import Awaitable, Callable, NoReturn, Optional +from typing import Awaitable, Callable, NoReturn from twisted.internet.defer import Deferred from twisted.web.resource import Resource @@ -309,7 +309,7 @@ class OptionsResourceTests(unittest.TestCase): class WrapHtmlRequestHandlerTests(unittest.TestCase): class TestResource(DirectServeHtmlResource): - callback: Optional[Callable[..., Awaitable[None]]] + callback: Callable[..., Awaitable[None]] | None async def _async_render_GET(self, request: SynapseRequest) -> None: assert self.callback is not None diff --git a/tests/test_state.py b/tests/test_state.py index 6e5a6d845d..7df95ebf8b 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -24,7 +24,6 @@ from typing import ( Generator, Iterable, Iterator, - Optional, ) from unittest.mock import AsyncMock, Mock @@ -48,12 +47,12 @@ _next_event_id = 1000 def create_event( - name: Optional[str] = None, - type: Optional[str] = None, - state_key: Optional[str] = None, + name: str | None = None, + type: str | None = None, + state_key: str | None = None, depth: int = 2, - event_id: Optional[str] = None, - prev_events: Optional[list[tuple[str, dict]]] = None, + event_id: str | None = None, + prev_events: list[tuple[str, dict]] | None = None, **kwargs: Any, ) -> EventBase: global _next_event_id @@ -106,7 +105,7 @@ class _DummyStore: return groups async def get_state_ids_for_group( - self, state_group: int, state_filter: Optional[StateFilter] = None + self, state_group: int, state_filter: StateFilter | None = None ) -> MutableStateMap[str]: return self._group_to_state[state_group] @@ -114,9 +113,9 @@ class _DummyStore: self, event_id: str, room_id: str, - prev_group: Optional[int], - delta_ids: Optional[StateMap[str]], - current_state_ids: Optional[StateMap[str]], + prev_group: int | None, + delta_ids: StateMap[str] | None, + current_state_ids: StateMap[str] | None, ) -> int: state_group = self._next_group self._next_group += 1 @@ -147,7 +146,7 @@ class _DummyStore: async def get_state_group_delta( self, name: str - ) -> tuple[Optional[int], Optional[StateMap[str]]]: + ) -> tuple[int | None, StateMap[str] | None]: return None, None def register_events(self, events: Iterable[EventBase]) -> None: diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 9cdb456b1b..a90fc5884d 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Any, Optional +from typing import Any import synapse.server from synapse.api.constants import EventTypes @@ -36,8 +36,8 @@ async def inject_member_event( room_id: str, sender: str, membership: str, - target: Optional[str] = None, - extra_content: Optional[dict] = None, + target: str | None = None, + extra_content: dict | None = None, **kwargs: Any, ) -> EventBase: """Inject a membership event into a room.""" @@ -61,8 +61,8 @@ async def inject_member_event( async def inject_event( hs: synapse.server.HomeServer, - room_version: Optional[str] = None, - prev_event_ids: Optional[list[str]] = None, + room_version: str | None = None, + prev_event_ids: list[str] | None = None, **kwargs: Any, ) -> EventBase: """Inject a generic event into a room @@ -86,8 +86,8 @@ async def inject_event( async def create_event( hs: synapse.server.HomeServer, - room_version: Optional[str] = None, - prev_event_ids: Optional[list[str]] = None, + room_version: str | None = None, + prev_event_ids: list[str] | None = None, **kwargs: Any, ) -> tuple[EventBase, EventContext]: if room_version is None: diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py index aff1626295..a9a4b98df2 100644 --- a/tests/test_utils/html_parsers.py +++ b/tests/test_utils/html_parsers.py @@ -20,7 +20,7 @@ # from html.parser import HTMLParser -from typing import Iterable, NoReturn, Optional +from typing import Iterable, NoReturn class TestHtmlParser(HTMLParser): @@ -33,13 +33,13 @@ class TestHtmlParser(HTMLParser): self.links: list[str] = [] # the values of any hidden s: map from name to value - self.hiddens: dict[str, Optional[str]] = {} + self.hiddens: dict[str, str | None] = {} # the values of any radio buttons: map from name to list of values - self.radios: dict[str, list[Optional[str]]] = {} + self.radios: dict[str, list[str | None]] = {} def handle_starttag( - self, tag: str, attrs: Iterable[tuple[str, Optional[str]]] + self, tag: str, attrs: Iterable[tuple[str, str | None]] ) -> None: attr_dict = dict(attrs) if tag == "a": diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index c2d6af029a..837a04077c 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -23,7 +23,7 @@ import base64 import json from hashlib import sha256 -from typing import Any, ContextManager, Optional +from typing import Any, ContextManager from unittest.mock import Mock, patch from urllib.parse import parse_qs @@ -45,8 +45,8 @@ class FakeAuthorizationGrant: client_id: str redirect_uri: str scope: str - nonce: Optional[str] - sid: Optional[str] + nonce: str | None + sid: str | None class FakeOidcServer: @@ -140,7 +140,7 @@ class FakeOidcServer: def get_jwks(self) -> dict: return self._jwks.as_dict() - def get_userinfo(self, access_token: str) -> Optional[dict]: + def get_userinfo(self, access_token: str) -> dict | None: """Given an access token, get the userinfo of the associated session.""" session = self._sessions.get(access_token, None) if session is None: @@ -220,7 +220,7 @@ class FakeOidcServer: scope: str, redirect_uri: str, userinfo: dict, - nonce: Optional[str] = None, + nonce: str | None = None, with_sid: bool = False, ) -> tuple[str, FakeAuthorizationGrant]: """Start an authorization request, and get back the code to use on the authorization endpoint.""" @@ -242,7 +242,7 @@ class FakeOidcServer: return code, grant - def exchange_code(self, code: str) -> Optional[dict[str, Any]]: + def exchange_code(self, code: str) -> dict[str, Any] | None: grant = self._authorization_grants.pop(code, None) if grant is None: return None @@ -296,11 +296,11 @@ class FakeOidcServer: self, method: str, uri: str, - data: Optional[bytes] = None, - headers: Optional[Headers] = None, + data: bytes | None = None, + headers: Headers | None = None, ) -> IResponse: """The override of the SimpleHttpClient#request() method""" - access_token: Optional[str] = None + access_token: str | None = None if headers is None: headers = Headers() @@ -346,7 +346,7 @@ class FakeOidcServer: """Handles requests to the OIDC well-known document.""" return FakeResponse.json(payload=self.get_metadata()) - def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse: + def _get_userinfo_handler(self, access_token: str | None) -> IResponse: """Handles requests to the userinfo endpoint.""" if access_token is None: return FakeResponse(code=401) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 9a8cad6454..06598c29de 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -18,7 +18,6 @@ # # import logging -from typing import Optional from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -693,9 +692,9 @@ async def inject_message_event( hs: HomeServer, room_id: str, sender: str, - body: Optional[str] = "testytest", - soft_failed: Optional[bool] = False, - policy_server_spammy: Optional[bool] = False, + body: str | None = "testytest", + soft_failed: bool | None = False, + policy_server_spammy: bool | None = False, ) -> EventBase: return await inject_event( hs, diff --git a/tests/unittest.py b/tests/unittest.py index 049a92caaa..7ea29364db 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -37,10 +37,8 @@ from typing import ( Iterable, Mapping, NoReturn, - Optional, Protocol, TypeVar, - Union, ) from unittest.mock import Mock, patch @@ -274,7 +272,7 @@ class TestCase(unittest.TestCase): actual_items: AbstractSet[TV], expected_items: AbstractSet[TV], exact: bool = False, - message: Optional[str] = None, + message: str | None = None, ) -> None: """ Assert that all of the `expected_items` are included in the `actual_items`. @@ -573,17 +571,17 @@ class HomeserverTestCase(TestCase): def make_request( self, - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, str, JsonDict] = b"", - access_token: Optional[str] = None, + method: bytes | str, + path: bytes | str, + content: bytes | str | JsonDict = b"", + access_token: str | None = None, request: type[Request] = SynapseRequest, shorthand: bool = True, - federation_auth_origin: Optional[bytes] = None, - content_type: Optional[bytes] = None, + federation_auth_origin: bytes | None = None, + content_type: bytes | None = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[Iterable[CustomHeaderType]] = None, + custom_headers: Iterable[CustomHeaderType] | None = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ @@ -636,10 +634,10 @@ class HomeserverTestCase(TestCase): def setup_test_homeserver( self, - server_name: Optional[str] = None, - config: Optional[JsonDict] = None, - reactor: Optional[ISynapseReactor] = None, - clock: Optional[Clock] = None, + server_name: str | None = None, + config: JsonDict | None = None, + reactor: ISynapseReactor | None = None, + clock: Clock | None = None, **extra_homeserver_attributes: Any, ) -> HomeServer: """ @@ -746,8 +744,8 @@ class HomeserverTestCase(TestCase): self, username: str, password: str, - admin: Optional[bool] = False, - displayname: Optional[str] = None, + admin: bool | None = False, + displayname: str | None = None, ) -> str: """ Register a user. Requires the Admin API be registered. @@ -798,7 +796,7 @@ class HomeserverTestCase(TestCase): username: str, appservice_token: str, inhibit_login: bool = False, - ) -> tuple[str, Optional[str]]: + ) -> tuple[str, str | None]: """Register an appservice user as an application service. Requires the client-facing registration API be registered. @@ -829,9 +827,9 @@ class HomeserverTestCase(TestCase): self, username: str, password: str, - device_id: Optional[str] = None, - additional_request_fields: Optional[dict[str, str]] = None, - custom_headers: Optional[Iterable[CustomHeaderType]] = None, + device_id: str | None = None, + additional_request_fields: dict[str, str] | None = None, + custom_headers: Iterable[CustomHeaderType] | None = None, ) -> str: """ Log in a user, and get an access token. Requires the Login API be registered. @@ -870,7 +868,7 @@ class HomeserverTestCase(TestCase): room_id: str, user: UserID, soft_failed: bool = False, - prev_event_ids: Optional[list[str]] = None, + prev_event_ids: list[str] | None = None, ) -> str: """ Create and send an event. @@ -971,9 +969,9 @@ class FederatingHomeserverTestCase(HomeserverTestCase): self, method: str, path: str, - content: Optional[JsonDict] = None, + content: JsonDict | None = None, await_result: bool = True, - custom_headers: Optional[Iterable[CustomHeaderType]] = None, + custom_headers: Iterable[CustomHeaderType] | None = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """Make an inbound signed federation request to this server @@ -1038,7 +1036,7 @@ def _auth_header_for_request( signing_key: signedjson.key.SigningKey, method: str, path: str, - content: Optional[JsonDict], + content: JsonDict | None, ) -> str: """Build a suitable Authorization header for an outgoing federation request""" request_description: JsonDict = { diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index e27f84fa6d..3fab6c4c57 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -25,7 +25,6 @@ from typing import ( Iterable, Mapping, NoReturn, - Optional, cast, ) from unittest import mock @@ -242,7 +241,7 @@ class DescriptorTestCase(unittest.TestCase): """The wrapped function returns a failure""" class Cls: - result: Optional[Deferred] = None + result: Deferred | None = None call_count = 0 server_name = "test_server" # nb must be called this for @cached _, clock = get_clock() # nb must be called this for @cached diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 8fbee12fb9..a6b7ddf485 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -19,7 +19,7 @@ # import logging import traceback -from typing import Any, Coroutine, NoReturn, Optional, TypeVar +from typing import Any, Coroutine, NoReturn, TypeVar from parameterized import parameterized_class @@ -71,7 +71,7 @@ class ObservableDeferredTest(TestCase): observer1.addBoth(check_called_first) # store the results - results: list[Optional[int]] = [None, None] + results: list[int | None] = [None, None] def check_val(res: int, idx: int) -> int: results[idx] = res @@ -102,7 +102,7 @@ class ObservableDeferredTest(TestCase): observer1.addBoth(check_called_first) # store the results - results: list[Optional[Failure]] = [None, None] + results: list[Failure | None] = [None, None] def check_failure(res: Failure, idx: int) -> None: results[idx] = res diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index ab2e2f6291..b7a23dcd9d 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -22,7 +22,7 @@ from contextlib import contextmanager from os import PathLike from pathlib import Path -from typing import Generator, Optional, Union, cast +from typing import Generator, cast from unittest.mock import patch from packaging.markers import default_environment as packaging_default_environment @@ -44,7 +44,7 @@ class DummyDistribution(metadata.Distribution): def version(self) -> str: return self._version - def locate_file(self, path: Union[str, PathLike]) -> Path: + def locate_file(self, path: str | PathLike) -> Path: raise NotImplementedError() def read_text(self, filename: str) -> None: @@ -63,7 +63,7 @@ distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type] class TestDependencyChecker(TestCase): @contextmanager def mock_installed_package( - self, distribution: Optional[DummyDistribution] + self, distribution: DummyDistribution | None ) -> Generator[None, None, None]: """Pretend that looking up any package yields the given `distribution`. diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index ab0143e605..e6d54ddb8d 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -20,7 +20,7 @@ import threading from io import BytesIO -from typing import BinaryIO, Generator, Optional, cast +from typing import BinaryIO, Generator, cast from unittest.mock import NonCallableMock from zope.interface import implementer @@ -127,7 +127,7 @@ class FileConsumerTests(unittest.TestCase): @implementer(IPullProducer) class DummyPullProducer: def __init__(self) -> None: - self.consumer: Optional[BackgroundFileConsumer] = None + self.consumer: BackgroundFileConsumer | None = None self.deferred: "defer.Deferred[object]" = defer.Deferred() def resumeProducing(self) -> None: @@ -159,7 +159,7 @@ class BlockingBytesWrite: self.closed = False self.write_lock = threading.Lock() - self._notify_write_deferred: Optional[defer.Deferred] = None + self._notify_write_deferred: defer.Deferred | None = None self._number_of_writes = 0 def write(self, write_bytes: bytes) -> None: diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 20281d04fe..d3b123c778 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from twisted.internet import defer from twisted.internet.defer import Deferred @@ -139,7 +138,7 @@ def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float return (reactor.seconds() - start_time) * 1000 -def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings: +def build_rc_config(settings: dict | None = None) -> FederationRatelimitSettings: config_dict = default_config("test") config_dict.update(settings or {}) config = HomeServerConfig() diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py index de9e381489..e33ded8a7f 100644 --- a/tests/util/test_task_scheduler.py +++ b/tests/util/test_task_scheduler.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Optional from twisted.internet.defer import Deferred from twisted.internet.testing import MemoryReactor @@ -43,7 +42,7 @@ class TestTaskScheduler(HomeserverTestCase): async def _test_task( self, task: ScheduledTask - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: # This test task will copy the parameters to the result result = None if task.params: @@ -86,7 +85,7 @@ class TestTaskScheduler(HomeserverTestCase): async def _sleeping_task( self, task: ScheduledTask - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: # Sleep for a second await self.hs.get_clock().sleep(1) return TaskStatus.COMPLETE, None, None @@ -152,7 +151,7 @@ class TestTaskScheduler(HomeserverTestCase): async def _raising_task( self, task: ScheduledTask - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: raise Exception("raising") def test_schedule_raising_task(self) -> None: @@ -166,7 +165,7 @@ class TestTaskScheduler(HomeserverTestCase): async def _resumable_task( self, task: ScheduledTask - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: if task.result and "in_progress" in task.result: return TaskStatus.COMPLETE, {"success": True}, None else: @@ -204,7 +203,7 @@ class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase): async def _test_task( self, task: ScheduledTask - ) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + ) -> tuple[TaskStatus, JsonMapping | None, str | None]: return (TaskStatus.COMPLETE, None, None) @override_config({"run_background_tasks_on": "worker1"}) diff --git a/tests/utils.py b/tests/utils.py index b3d59a0ebe..4052c9a4fb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -25,9 +25,7 @@ import signal from types import FrameType, TracebackType from typing import ( Literal, - Optional, TypeVar, - Union, overload, ) @@ -141,7 +139,7 @@ def default_config(server_name: str, parse: Literal[True]) -> HomeServerConfig: def default_config( server_name: str, parse: bool = False -) -> Union[dict[str, object], HomeServerConfig]: +) -> dict[str, object] | HomeServerConfig: """ Create a reasonable test config. @@ -320,13 +318,13 @@ class test_timeout: ``` """ - def __init__(self, seconds: int, error_message: Optional[str] = None) -> None: + def __init__(self, seconds: int, error_message: str | None = None) -> None: self.error_message = f"Test timed out after {seconds}s" if error_message is not None: self.error_message += f": {error_message}" self.seconds = seconds - def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None: + def handle_timeout(self, signum: int, frame: FrameType | None) -> None: raise TestTimeout(self.error_message) def __enter__(self) -> None: @@ -335,8 +333,8 @@ class test_timeout: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: signal.alarm(0)