1
0

disallow-untyped-defs for synapse.replication.http._base

This commit is contained in:
David Robertson
2021-10-07 18:56:47 +01:00
parent 0e68a4b162
commit be2adcf695
3 changed files with 37 additions and 11 deletions

View File

@@ -99,6 +99,9 @@ disallow_untyped_defs = True
[mypy-synapse.rest.*]
disallow_untyped_defs = True
[mypy-synapse.replication.http._base]
disallow_untyped_defs = True
[mypy-synapse.state.*]
disallow_untyped_defs = True

View File

@@ -15,7 +15,17 @@
"""Contains functions for registering clients."""
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Iterable,
List,
Mapping,
Optional,
Tuple,
)
from prometheus_client import Counter
from typing_extensions import TypedDict
@@ -103,6 +113,7 @@ class RegistrationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self._register_device_client: Callable[..., Awaitable[Mapping[str, Any]]]
if hs.config.worker.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(

View File

@@ -17,14 +17,18 @@ import logging
import re
import urllib
from inspect import signature
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
from prometheus_client import Counter, Gauge
from twisted.web.http import Request
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -113,10 +117,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret
def _check_auth(self, request) -> None:
def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if auth_headers is None:
raise RuntimeError("No Authorization header.")
if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
@@ -129,7 +134,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
raise RuntimeError("Invalid Authorization header.")
@abc.abstractmethod
async def _serialize_payload(**kwargs):
async def _serialize_payload(**kwargs: str) -> Dict[str, Any]:
"""Static method that is called when creating a request.
Concrete implementations should have explicit parameters (rather than
@@ -144,7 +149,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
return {}
@abc.abstractmethod
async def _handle_request(self, request, **kwargs):
async def _handle_request(
self, request: Request, **kwargs: str
) -> Tuple[int, JsonDict]:
"""Handle incoming request.
This is called with the request object and PATH_ARGS.
@@ -156,7 +163,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
pass
@classmethod
def make_client(cls, hs):
def make_client(cls, hs: HomeServer) -> Callable[..., Awaitable[JsonDict]]:
"""Create a client that makes requests.
Returns a callable that accepts the same parameters as
@@ -183,7 +190,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
async def send_request(*, instance_name="master", **kwargs):
async def send_request(
*, instance_name: str = "master", **kwargs: str
) -> JsonDict:
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":
@@ -207,6 +216,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
txn_id = random_string(10)
url_args.append(txn_id)
request_func: Callable[..., Awaitable[JsonDict]]
if cls.METHOD == "POST":
request_func = client.post_json_get_json
elif cls.METHOD == "PUT":
@@ -264,7 +274,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
return send_request
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
"""Called by the server to register this as a handler to the
appropriate path.
"""
@@ -285,7 +295,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
self.__class__.__name__,
)
async def _check_auth_and_handle(self, request, **kwargs):
async def _check_auth_and_handle(
self, request: Request, **kwargs: str
) -> Tuple[int, JsonDict]:
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
@@ -301,7 +313,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
txn_id = kwargs.pop("txn_id")
return await self.response_cache.wrap(
txn_id, self._handle_request, request, **kwargs
txn_id, self._handle_request, request, cache_context=False, **kwargs
)
return await self._handle_request(request, **kwargs)