1
0

Use type hinting generics in standard collections (#19046)

aka PEP 585, added in Python 3.9

 - https://peps.python.org/pep-0585/
 - https://docs.astral.sh/ruff/rules/non-pep585-annotation/
This commit is contained in:
Andrew Ferrazzutti
2025-10-22 17:48:19 -04:00
committed by GitHub
parent cba3a814c6
commit fc244bb592
539 changed files with 4599 additions and 5066 deletions

View File

@@ -18,7 +18,7 @@
#
#
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional
from twisted.web.server import Request
@@ -41,7 +41,7 @@ class AdditionalResource(DirectServeJsonResource):
def __init__(
self,
hs: "HomeServer",
handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
handler: Callable[[Request], Awaitable[Optional[tuple[int, Any]]]],
):
"""Initialise AdditionalResource
@@ -56,7 +56,7 @@ class AdditionalResource(DirectServeJsonResource):
super().__init__(clock=hs.get_clock())
self._handler = handler
async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
async def _async_render(self, request: Request) -> Optional[tuple[int, Any]]:
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return await self._handler(request)

View File

@@ -27,12 +27,9 @@ from typing import (
Any,
BinaryIO,
Callable,
Dict,
List,
Mapping,
Optional,
Protocol,
Tuple,
Union,
)
@@ -135,10 +132,10 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
# the entries can either be Lists or bytes.
RawHeaderValue = Union[
StrSequence,
List[bytes],
List[Union[str, bytes]],
Tuple[bytes, ...],
Tuple[Union[str, bytes], ...],
list[bytes],
list[Union[str, bytes]],
tuple[bytes, ...],
tuple[Union[str, bytes], ...],
]
@@ -205,7 +202,7 @@ class _IPBlockingResolver:
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
addresses: List[IAddress] = []
addresses: list[IAddress] = []
def _callback() -> None:
has_bad_ip = False
@@ -349,7 +346,7 @@ class BaseHttpClient:
def __init__(
self,
hs: "HomeServer",
treq_args: Optional[Dict[str, Any]] = None,
treq_args: Optional[dict[str, Any]] = None,
):
self.hs = hs
self.server_name = hs.hostname
@@ -479,7 +476,7 @@ class BaseHttpClient:
async def post_urlencoded_get_json(
self,
uri: str,
args: Optional[Mapping[str, Union[str, List[str]]]] = None,
args: Optional[Mapping[str, Union[str, list[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@@ -707,7 +704,7 @@ class BaseHttpClient:
max_size: Optional[int] = None,
headers: Optional[RawHeaders] = None,
is_allowed_content_type: Optional[Callable[[str], bool]] = None,
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
) -> tuple[int, dict[bytes, list[bytes]], str, int]:
"""GETs a file from a given URL
Args:
url: The URL to GET
@@ -815,7 +812,7 @@ class SimpleHttpClient(BaseHttpClient):
def __init__(
self,
hs: "HomeServer",
treq_args: Optional[Dict[str, Any]] = None,
treq_args: Optional[dict[str, Any]] = None,
ip_allowlist: Optional[IPSet] = None,
ip_blocklist: Optional[IPSet] = None,
use_proxy: bool = False,

View File

@@ -19,7 +19,7 @@
#
import logging
import urllib.parse
from typing import Any, Generator, List, Optional
from typing import Any, Generator, Optional
from urllib.request import ( # type: ignore[attr-defined]
proxy_bypass_environment,
)
@@ -413,7 +413,7 @@ class MatrixHostnameEndpoint:
# to try and if that doesn't work then we'll have an exception.
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
async def _resolve_server(self) -> List[Server]:
async def _resolve_server(self) -> list[Server]:
"""Resolves the server name to a list of hosts and ports to attempt to
connect to.
"""

View File

@@ -22,7 +22,7 @@
import logging
import random
import time
from typing import Any, Callable, Dict, List
from typing import Any, Callable
import attr
@@ -34,7 +34,7 @@ from synapse.logging.context import make_deferred_yieldable
logger = logging.getLogger(__name__)
SERVER_CACHE: Dict[bytes, List["Server"]] = {}
SERVER_CACHE: dict[bytes, list["Server"]] = {}
@attr.s(auto_attribs=True, slots=True, frozen=True)
@@ -58,11 +58,11 @@ class Server:
expires: int = 0
def _sort_server_list(server_list: List[Server]) -> List[Server]:
def _sort_server_list(server_list: list[Server]) -> list[Server]:
"""Given a list of SRV records sort them into priority order and shuffle
each priority with the given weight.
"""
priority_map: Dict[int, List[Server]] = {}
priority_map: dict[int, list[Server]] = {}
for server in server_list:
priority_map.setdefault(server.priority, []).append(server)
@@ -116,14 +116,14 @@ class SrvResolver:
def __init__(
self,
dns_client: Any = client,
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
cache: dict[bytes, list[Server]] = SERVER_CACHE,
get_time: Callable[[], float] = time.time,
):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
async def resolve_service(self, service_name: bytes) -> List[Server]:
async def resolve_service(self, service_name: bytes) -> list[Server]:
"""Look up a SRV record
Args:

View File

@@ -22,7 +22,7 @@ import logging
import random
import time
from io import BytesIO
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Optional
import attr
@@ -188,7 +188,7 @@ class WellKnownResolver:
return WellKnownLookupResult(delegated_server=result)
async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
async def _fetch_well_known(self, server_name: bytes) -> tuple[bytes, float]:
"""Actually fetch and parse a .well-known, without checking the cache
Args:
@@ -251,7 +251,7 @@ class WellKnownResolver:
async def _make_well_known_request(
self, server_name: bytes, retry: bool
) -> Tuple[IResponse, bytes]:
) -> tuple[IResponse, bytes]:
"""Make the well known request.
This will retry the request if requested and it fails (with unable
@@ -348,7 +348,7 @@ def _cache_period_from_headers(
return None
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
def _parse_cache_control(headers: Headers) -> dict[bytes, Optional[bytes]]:
cache_controls = {}
cache_control_headers = headers.getRawHeaders(b"cache-control") or []
for hdr in cache_control_headers:

View File

@@ -31,13 +31,10 @@ from typing import (
Any,
BinaryIO,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
TextIO,
Tuple,
TypeVar,
Union,
cast,
@@ -253,7 +250,7 @@ class JsonParser(_BaseJsonParser[JsonDict]):
return isinstance(v, dict)
class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
class LegacyJsonSendParser(_BaseJsonParser[tuple[int, JsonDict]]):
"""Ensure the legacy responses of /send_join & /send_leave are correct."""
def __init__(self) -> None:
@@ -667,7 +664,7 @@ class MatrixFederationHttpClient:
)
# Inject the span into the headers
headers_dict: Dict[bytes, List[bytes]] = {}
headers_dict: dict[bytes, list[bytes]] = {}
opentracing.inject_header_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
@@ -913,7 +910,7 @@ class MatrixFederationHttpClient:
url_bytes: bytes,
content: Optional[JsonDict] = None,
destination_is: Optional[bytes] = None,
) -> List[bytes]:
) -> list[bytes]:
"""
Builds the Authorization headers for a federation request
Args:
@@ -1291,7 +1288,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]: ...
) -> tuple[JsonDict, dict[bytes, list[bytes]]]: ...
@overload
async def get_json_with_headers(
@@ -1304,7 +1301,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = ...,
try_trailing_slash_on_400: bool = ...,
parser: ByteParser[T] = ...,
) -> Tuple[T, Dict[bytes, List[bytes]]]: ...
) -> tuple[T, dict[bytes, list[bytes]]]: ...
async def get_json_with_headers(
self,
@@ -1316,7 +1313,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser[T]] = None,
) -> Tuple[Union[JsonDict, T], Dict[bytes, List[bytes]]]:
) -> tuple[Union[JsonDict, T], dict[bytes, list[bytes]]]:
"""GETs some json from the given host homeserver and path
Args:
@@ -1484,7 +1481,7 @@ class MatrixFederationHttpClient:
retry_on_dns_fail: bool = True,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
) -> tuple[int, dict[bytes, list[bytes]]]:
"""GETs a file from a given homeserver
Args:
destination: The remote server to send the HTTP request to.
@@ -1645,7 +1642,7 @@ class MatrixFederationHttpClient:
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
ignore_backoff: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
) -> tuple[int, dict[bytes, list[bytes]], bytes]:
"""GETs a file from a given homeserver over the federation /download endpoint
Args:
destination: The remote server to send the HTTP request to.

View File

@@ -22,7 +22,7 @@
import json
import logging
import urllib.parse
from typing import TYPE_CHECKING, Any, Optional, Set, Tuple, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from twisted.internet import protocol
from twisted.internet.interfaces import ITCPTransport
@@ -66,7 +66,7 @@ assert all(header.lower() == header for header in HOP_BY_HOP_HEADERS_LOWERCASE)
def parse_connection_header_value(
connection_header_value: Optional[bytes],
) -> Set[str]:
) -> set[str]:
"""
Parse the `Connection` header to determine which headers we should not be copied
over from the remote response.
@@ -86,7 +86,7 @@ def parse_connection_header_value(
The set of header names that should not be copied over from the remote response.
The keys are lowercased.
"""
extra_headers_to_remove: Set[str] = set()
extra_headers_to_remove: set[str] = set()
if connection_header_value:
extra_headers_to_remove = {
connection_option.decode("ascii").strip().lower()
@@ -140,7 +140,7 @@ class ProxyResource(_AsyncResource):
"Invalid Proxy-Authorization header.", Codes.UNAUTHORIZED
)
async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]:
async def _async_render(self, request: "SynapseRequest") -> tuple[int, Any]:
uri = urllib.parse.urlparse(request.uri)
assert uri.scheme == b"matrix-federation"

View File

@@ -21,7 +21,7 @@
import logging
import random
import re
from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Collection, Optional, Sequence, Union, cast
from urllib.parse import urlparse
from urllib.request import ( # type: ignore[attr-defined]
proxy_bypass_environment,
@@ -139,7 +139,7 @@ class ProxyAgent(_AgentBase):
else:
self.proxy_reactor = proxy_reactor
self._endpoint_kwargs: Dict[str, Any] = {}
self._endpoint_kwargs: dict[str, Any] = {}
if connectTimeout is not None:
self._endpoint_kwargs["timeout"] = connectTimeout
if bindAddress is not None:
@@ -182,7 +182,7 @@ class ProxyAgent(_AgentBase):
"`federation_proxy_credentials` are required when using `federation_proxy_locations`"
)
endpoints: List[IStreamClientEndpoint] = []
endpoints: list[IStreamClientEndpoint] = []
for federation_proxy_location in federation_proxy_locations:
endpoint: IStreamClientEndpoint
if isinstance(federation_proxy_location, InstanceTcpLocationConfig):
@@ -369,7 +369,7 @@ def http_proxy_endpoint(
timeout: float = 30,
bindAddress: Optional[Union[bytes, str, tuple[Union[bytes, str], int]]] = None,
attemptDelay: Optional[float] = None,
) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
) -> tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
"""Parses an http proxy setting and returns an endpoint for the proxy
Args:
@@ -418,7 +418,7 @@ def http_proxy_endpoint(
def parse_proxy(
proxy: bytes, default_scheme: bytes = b"http", default_port: int = 1080
) -> Tuple[bytes, bytes, int, Optional[ProxyCredentials]]:
) -> tuple[bytes, bytes, int, Optional[ProxyCredentials]]:
"""
Parse a proxy connection string.
@@ -487,7 +487,7 @@ class _RandomSampleEndpoints:
return run_in_background(self._do_connect, protocol_factory)
async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
failures: List[Failure] = []
failures: list[Failure] = []
for endpoint in random.sample(self._endpoints, k=len(self._endpoints)):
try:
return await endpoint.connect(protocol_factory)

View File

@@ -20,7 +20,7 @@
#
import logging
from typing import Dict, Optional
from typing import Optional
from zope.interface import implementer
@@ -60,7 +60,7 @@ class ReplicationEndpointFactory:
def __init__(
self,
reactor: ISynapseReactor,
instance_map: Dict[str, InstanceLocationConfig],
instance_map: dict[str, InstanceLocationConfig],
context_factory: IPolicyForHTTPS,
) -> None:
self.reactor = reactor
@@ -117,7 +117,7 @@ class ReplicationAgent(_AgentBase):
def __init__(
self,
reactor: ISynapseReactor,
instance_map: Dict[str, InstanceLocationConfig],
instance_map: dict[str, InstanceLocationConfig],
contextFactory: IPolicyForHTTPS,
connectTimeout: Optional[float] = None,
bindAddress: Optional[bytes] = None,

View File

@@ -22,7 +22,7 @@
import logging
import threading
import traceback
from typing import Dict, Mapping, Set, Tuple
from typing import Mapping
from prometheus_client.core import Counter, Histogram
@@ -133,13 +133,13 @@ in_flight_requests_db_sched_duration = Counter(
labelnames=["method", "servlet", SERVER_NAME_LABEL],
)
_in_flight_requests: Set["RequestMetrics"] = set()
_in_flight_requests: set["RequestMetrics"] = set()
# Protects the _in_flight_requests set from concurrent access
_in_flight_requests_lock = threading.Lock()
def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
def _get_in_flight_counts() -> Mapping[tuple[str, ...], int]:
"""Returns a count of all in flight requests by (method, server_name)"""
# Cast to a list to prevent it changing while the Prometheus
# thread is collecting metrics
@@ -152,7 +152,7 @@ def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
# Map from (method, name) -> int, the number of in flight requests of that
# type. The key type is Tuple[str, str], but we leave the length unspecified
# for compatibility with LaterGauge's annotations.
counts: Dict[Tuple[str, ...], int] = {}
counts: dict[tuple[str, ...], int] = {}
for request_metric in request_metrics:
key = (
request_metric.method,

View File

@@ -33,14 +33,11 @@ from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Pattern,
Protocol,
Tuple,
Union,
cast,
)
@@ -267,7 +264,7 @@ def wrap_async_request_handler(
# it is actually called with a SynapseRequest and a kwargs dict for the params,
# but I can't figure out how to represent that.
ServletCallback = Callable[
..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
..., Union[None, Awaitable[None], tuple[int, Any], Awaitable[tuple[int, Any]]]
]
@@ -354,7 +351,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
async def _async_render(
self, request: "SynapseRequest"
) -> Optional[Tuple[int, Any]]:
) -> Optional[tuple[int, Any]]:
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for
different routing.
@@ -491,7 +488,7 @@ class JsonResource(DirectServeJsonResource):
self.clock = hs.get_clock()
super().__init__(canonical_json, extract_context, clock=self.clock)
# Map of path regex -> method -> callback.
self._routes: Dict[Pattern[str], Dict[bytes, _PathEntry]] = {}
self._routes: dict[Pattern[str], dict[bytes, _PathEntry]] = {}
self.hs = hs
def register_paths(
@@ -527,7 +524,7 @@ class JsonResource(DirectServeJsonResource):
def _get_handler_for_request(
self, request: "SynapseRequest"
) -> Tuple[ServletCallback, str, Dict[str, str]]:
) -> tuple[ServletCallback, str, dict[str, str]]:
"""Finds a callback method to handle the given request.
Returns:
@@ -556,7 +553,7 @@ class JsonResource(DirectServeJsonResource):
# Huh. No one wanted to handle that? Fiiiiiine.
raise UnrecognizedRequestError(code=404)
async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]:
async def _async_render(self, request: "SynapseRequest") -> tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
request.is_render_cancellable = is_function_cancellable(callback)
@@ -758,7 +755,7 @@ class _ByteProducer:
# Start producing if `registerProducer` was successful
self.resumeProducing()
def _send_data(self, data: List[bytes]) -> None:
def _send_data(self, data: list[bytes]) -> None:
"""
Send a list of bytes as a chunk of a response.
"""

View File

@@ -27,13 +27,10 @@ import urllib.parse as urlparse
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
overload,
)
@@ -548,7 +545,7 @@ EnumT = TypeVar("EnumT", bound=enum.Enum)
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
E: type[EnumT],
default: EnumT,
) -> EnumT: ...
@@ -557,7 +554,7 @@ def parse_enum(
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
E: type[EnumT],
*,
required: Literal[True],
) -> EnumT: ...
@@ -566,7 +563,7 @@ def parse_enum(
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
E: type[EnumT],
default: Optional[EnumT] = None,
required: bool = False,
) -> Optional[EnumT]:
@@ -637,18 +634,18 @@ def parse_strings_from_args(
*,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]: ...
) -> Optional[list[str]]: ...
@overload
def parse_strings_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
default: List[str],
default: list[str],
*,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]: ...
) -> list[str]: ...
@overload
@@ -659,29 +656,29 @@ def parse_strings_from_args(
required: Literal[True],
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]: ...
) -> list[str]: ...
@overload
def parse_strings_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[List[str]] = None,
default: Optional[list[str]] = None,
*,
required: bool = False,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]: ...
) -> Optional[list[str]]: ...
def parse_strings_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[List[str]] = None,
default: Optional[list[str]] = None,
required: bool = False,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
) -> Optional[list[str]]:
"""
Parse a string parameter from the request query string list.
@@ -892,7 +889,7 @@ def parse_json_object_from_request(
Model = TypeVar("Model", bound=BaseModel)
def validate_json_object(content: JsonDict, model_type: Type[Model]) -> Model:
def validate_json_object(content: JsonDict, model_type: type[Model]) -> Model:
"""Validate a deserialized JSON object using the given pydantic model.
Raises:
@@ -922,7 +919,7 @@ def validate_json_object(content: JsonDict, model_type: Type[Model]) -> Model:
def parse_and_validate_json_object_from_request(
request: Request, model_type: Type[Model]
request: Request, model_type: type[Model]
) -> Model:
"""Parse a JSON object from the body of a twisted HTTP request, then deserialise and
validate using the given pydantic model.
@@ -988,8 +985,8 @@ class ResolveRoomIdMixin:
self.room_member_handler = hs.get_room_member_handler()
async def resolve_room_id(
self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None
) -> Tuple[str, Optional[List[str]]]:
self, room_identifier: str, remote_room_hosts: Optional[list[str]] = None
) -> tuple[str, Optional[list[str]]]:
"""
Resolve a room identifier to a room ID, if necessary.

View File

@@ -22,7 +22,7 @@ import contextlib
import logging
import time
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Generator, Optional, Union
import attr
from zope.interface import implementer
@@ -266,7 +266,7 @@ class SynapseRequest(Request):
return self.method.decode("ascii")
return method
def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]:
def get_authenticated_entity(self) -> tuple[Optional[str], Optional[str]]:
"""
Get the "authenticated" entity of the request, which might be the user
performing the action, or a user being puppeted by a server admin.
@@ -783,7 +783,7 @@ class SynapseSite(ProxySite):
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
self.connections: List[Protocol] = []
self.connections: list[Protocol] = []
def buildProtocol(self, addr: IAddress) -> SynapseProtocol:
protocol = SynapseProtocol(