Compare commits
21 Commits
v1.99.0
...
squah/add_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4e644eae8d | ||
|
|
4976ae50aa | ||
|
|
08acc0c293 | ||
|
|
3f8a59f8a4 | ||
|
|
a89fc72fba | ||
|
|
342a502a1e | ||
|
|
89cb0f140e | ||
|
|
3544cfdaa1 | ||
|
|
d3f75f3c94 | ||
|
|
92b7b17c3d | ||
|
|
6720b8780f | ||
|
|
2bbad2930d | ||
|
|
3d89472339 | ||
|
|
62d3b915a5 | ||
|
|
c3eb1e3358 | ||
|
|
2326bbf099 | ||
|
|
46cdb4bd07 | ||
|
|
5a9991c0f9 | ||
|
|
0dc4178587 | ||
|
|
1ce7dbf42c | ||
|
|
2780bedb51 |
1
changelog.d/12583.misc
Normal file
1
changelog.d/12583.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add `@cancellable` decorator, for use on endpoint methods that can be cancelled when clients disconnect.
|
||||
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
|
||||
|
||||
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
|
||||
from synapse.api.urls import FEDERATION_V1_PREFIX
|
||||
from synapse.http.server import HttpServer, ServletCallback
|
||||
from synapse.http.server import HttpServer, ServletCallback, is_method_cancellable
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import run_in_background
|
||||
@@ -373,6 +373,17 @@ class BaseFederationServlet:
|
||||
if code is None:
|
||||
continue
|
||||
|
||||
if is_method_cancellable(code):
|
||||
# The wrapper added by `self._wrap` will inherit the cancellable flag,
|
||||
# but the wrapper itself does not support cancellation yet.
|
||||
# Once resolved, the cancellation tests in
|
||||
# `tests/federation/transport/server/test__base.py` can be re-enabled.
|
||||
raise Exception(
|
||||
f"{self.__class__.__name__}.on_{method} has been marked as "
|
||||
"cancellable, but federation servlets do not support cancellation "
|
||||
"yet."
|
||||
)
|
||||
|
||||
server.register_paths(
|
||||
method,
|
||||
(pattern,),
|
||||
|
||||
@@ -33,6 +33,7 @@ from typing import (
|
||||
Optional,
|
||||
Pattern,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
@@ -43,6 +44,7 @@ from typing_extensions import Protocol
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer, interfaces
|
||||
from twisted.internet.defer import CancelledError
|
||||
from twisted.python import failure
|
||||
from twisted.web import resource
|
||||
from twisted.web.server import NOT_DONE_YET, Request
|
||||
@@ -82,6 +84,61 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
|
||||
</html>
|
||||
"""
|
||||
|
||||
# A fictional HTTP status code for requests where the client has disconnected and we
|
||||
# successfully cancelled the request. Used only for logging purposes. Clients will never
|
||||
# observe this code unless cancellations leak across requests or we raise a
|
||||
# `CancelledError` ourselves.
|
||||
# Analogous to nginx's 499 status code:
|
||||
# https://github.com/nginx/nginx/blob/release-1.21.6/src/http/ngx_http_request.h#L128-L134
|
||||
HTTP_STATUS_REQUEST_CANCELLED = 499
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
_cancellable_method_names = frozenset(
|
||||
{
|
||||
# `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
|
||||
# methods
|
||||
"on_GET",
|
||||
"on_PUT",
|
||||
"on_POST",
|
||||
"on_DELETE",
|
||||
# `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
|
||||
# methods
|
||||
"_async_render_GET",
|
||||
"_async_render_PUT",
|
||||
"_async_render_POST",
|
||||
"_async_render_DELETE",
|
||||
"_async_render_OPTIONS",
|
||||
# `ReplicationEndpoint` methods
|
||||
"_handle_request",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def cancellable(method: F) -> F:
|
||||
"""Marks a servlet method as cancellable.
|
||||
|
||||
Usage:
|
||||
class SomeServlet(RestServlet):
|
||||
@cancellable
|
||||
async def on_GET(self, request: SynapseRequest) -> ...:
|
||||
...
|
||||
"""
|
||||
if method.__name__ not in _cancellable_method_names:
|
||||
raise ValueError(
|
||||
"@cancellable decorator can only be applied to servlet methods."
|
||||
)
|
||||
|
||||
method.cancellable = True # type: ignore[attr-defined]
|
||||
return method
|
||||
|
||||
|
||||
def is_method_cancellable(method: Callable[..., Any]) -> bool:
|
||||
"""Checks whether a servlet method is cancellable."""
|
||||
return getattr(method, "cancellable", False)
|
||||
|
||||
|
||||
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
||||
"""Sends a JSON error response to clients."""
|
||||
@@ -93,6 +150,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
||||
error_dict = exc.error_dict()
|
||||
|
||||
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
|
||||
elif f.check(CancelledError):
|
||||
error_code = HTTP_STATUS_REQUEST_CANCELLED
|
||||
error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN}
|
||||
|
||||
if not request._disconnected:
|
||||
logger.error(
|
||||
"Got cancellation before client disconnection from %r: %r",
|
||||
request.request_metrics.name,
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
error_code = 500
|
||||
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
|
||||
@@ -155,6 +223,16 @@ def return_html_error(
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type]
|
||||
)
|
||||
elif f.check(CancelledError):
|
||||
code = HTTP_STATUS_REQUEST_CANCELLED
|
||||
msg = "Request cancelled"
|
||||
|
||||
if not request._disconnected:
|
||||
logger.error(
|
||||
"Got cancellation before client disconnection when handling request %r",
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
msg = "Internal server error"
|
||||
@@ -223,6 +301,9 @@ class HttpServer(Protocol):
|
||||
If the regex contains groups these gets passed to the callback via
|
||||
an unpacked tuple.
|
||||
|
||||
The callback may be marked with the `@cancellable` decorator, which will
|
||||
cause request processing to be cancelled when clients disconnect early.
|
||||
|
||||
Args:
|
||||
method: The HTTP method to listen to.
|
||||
path_patterns: The regex used to match requests.
|
||||
@@ -253,7 +334,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
||||
|
||||
def render(self, request: SynapseRequest) -> int:
|
||||
"""This gets called by twisted every time someone sends us a request."""
|
||||
defer.ensureDeferred(self._async_render_wrapper(request))
|
||||
request.render_deferred = defer.ensureDeferred(
|
||||
self._async_render_wrapper(request)
|
||||
)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_async_request_handler
|
||||
@@ -289,6 +372,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
||||
|
||||
method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
|
||||
if method_handler:
|
||||
request.is_render_cancellable = is_method_cancellable(method_handler)
|
||||
|
||||
raw_callback_return = method_handler(request)
|
||||
|
||||
# Is it synchronous? We'll allow this for now.
|
||||
@@ -449,6 +534,8 @@ class JsonResource(DirectServeJsonResource):
|
||||
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
|
||||
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
|
||||
|
||||
request.is_render_cancellable = is_method_cancellable(callback)
|
||||
|
||||
# Make sure we have an appropriate name for this handler in prometheus
|
||||
# (rather than the default of JsonResource).
|
||||
request.request_metrics.name = servlet_classname
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
|
||||
import attr
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IAddress, IReactorTime
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.http import HTTPChannel
|
||||
@@ -91,6 +92,13 @@ class SynapseRequest(Request):
|
||||
# we can't yet create the logcontext, as we don't know the method.
|
||||
self.logcontext: Optional[LoggingContext] = None
|
||||
|
||||
# The `Deferred` to cancel if the client disconnects early. Expected to be set
|
||||
# by `Resource.render`.
|
||||
self.render_deferred: Optional["Deferred[None]"] = None
|
||||
# A boolean indicating whether `_render_deferred` should be cancelled if the
|
||||
# client disconnects early. Expected to be set during `Resource.render`.
|
||||
self.is_render_cancellable = False
|
||||
|
||||
global _next_request_seq
|
||||
self.request_seq = _next_request_seq
|
||||
_next_request_seq += 1
|
||||
@@ -357,7 +365,21 @@ class SynapseRequest(Request):
|
||||
{"event": "client connection lost", "reason": str(reason.value)}
|
||||
)
|
||||
|
||||
if not self._is_processing:
|
||||
if self._is_processing:
|
||||
if self.is_render_cancellable:
|
||||
if self.render_deferred is not None:
|
||||
# Throw a cancellation into the request processing, in the hope
|
||||
# that it will finish up sooner than it normally would.
|
||||
# The `self.processing()` context manager will call
|
||||
# `_finished_processing()` when done.
|
||||
with PreserveLoggingContext():
|
||||
self.render_deferred.cancel()
|
||||
else:
|
||||
logger.error(
|
||||
"Connection from client lost, but have no Deferred to "
|
||||
"cancel even though the request is marked as cancellable."
|
||||
)
|
||||
else:
|
||||
self._finished_processing()
|
||||
|
||||
def _started_processing(self, servlet_name: str) -> None:
|
||||
|
||||
@@ -26,7 +26,8 @@ from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import HttpResponseException, SynapseError
|
||||
from synapse.http import RequestTimedOutError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.server import HttpServer, is_method_cancellable
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging import opentracing
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.types import JsonDict
|
||||
@@ -310,6 +311,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
url_args = list(self.PATH_ARGS)
|
||||
method = self.METHOD
|
||||
|
||||
if self.CACHE and is_method_cancellable(self._handle_request):
|
||||
raise Exception(
|
||||
f"{self.__class__.__name__} has been marked as cancellable, but CACHE "
|
||||
"is set. The cancellable flag would have no effect."
|
||||
)
|
||||
|
||||
if self.CACHE:
|
||||
url_args.append("txn_id")
|
||||
|
||||
@@ -324,7 +331,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
)
|
||||
|
||||
async def _check_auth_and_handle(
|
||||
self, request: Request, **kwargs: Any
|
||||
self, request: SynapseRequest, **kwargs: Any
|
||||
) -> Tuple[int, JsonDict]:
|
||||
"""Called on new incoming requests when caching is enabled. Checks
|
||||
if there is a cached response for the request and returns that,
|
||||
@@ -340,8 +347,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
if self.CACHE:
|
||||
txn_id = kwargs.pop("txn_id")
|
||||
|
||||
# We ignore the `@cancellable` flag, since cancellation wouldn't interupt
|
||||
# `_handle_request` and `ResponseCache` does not handle cancellation
|
||||
# correctly yet. In particular, there may be issues to do with logging
|
||||
# context lifetimes.
|
||||
|
||||
return await self.response_cache.wrap(
|
||||
txn_id, self._handle_request, request, **kwargs
|
||||
)
|
||||
|
||||
# The `@cancellable` decorator may be applied to `_handle_request`. But we
|
||||
# told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
|
||||
# so we have to set up the cancellable flag ourselves.
|
||||
request.is_render_cancellable = is_method_cancellable(self._handle_request)
|
||||
|
||||
return await self._handle_request(request, **kwargs)
|
||||
|
||||
13
tests/federation/transport/server/__init__.py
Normal file
13
tests/federation/transport/server/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
114
tests/federation/transport/server/test__base.py
Normal file
114
tests/federation/transport/server/test__base.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.federation.transport.server import BaseFederationServlet
|
||||
from synapse.federation.transport.server._base import Authenticator
|
||||
from synapse.http.server import JsonResource, cancellable
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
|
||||
from tests import unittest
|
||||
from tests.http.server._base import EndpointCancellationTestHelperMixin
|
||||
|
||||
|
||||
class CancellableFederationServlet(BaseFederationServlet):
|
||||
PATH = "/sleep"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hs: HomeServer,
|
||||
authenticator: Authenticator,
|
||||
ratelimiter: FederationRateLimiter,
|
||||
server_name: str,
|
||||
):
|
||||
super().__init__(hs, authenticator, ratelimiter, server_name)
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@cancellable
|
||||
async def on_GET(
|
||||
self, origin: str, content: None, query: Dict[bytes, List[bytes]]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, {"result": True}
|
||||
|
||||
async def on_POST(
|
||||
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, {"result": True}
|
||||
|
||||
|
||||
class BaseFederationServletCancellationTests(
|
||||
unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
|
||||
):
|
||||
"""Tests for `BaseFederationServlet` cancellation."""
|
||||
|
||||
skip = "`BaseFederationServlet` does not support cancellation yet."
|
||||
|
||||
path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}"
|
||||
|
||||
def create_test_resource(self):
|
||||
"""Overrides `HomeserverTestCase.create_test_resource`."""
|
||||
resource = JsonResource(self.hs)
|
||||
|
||||
CancellableFederationServlet(
|
||||
hs=self.hs,
|
||||
authenticator=Authenticator(self.hs),
|
||||
ratelimiter=self.hs.get_federation_ratelimiter(),
|
||||
server_name=self.hs.hostname,
|
||||
).register(resource)
|
||||
|
||||
return resource
|
||||
|
||||
def test_cancellable_disconnect(self) -> None:
|
||||
"""Test that handlers with the `@cancellable` flag can be cancelled."""
|
||||
channel = self.make_signed_federation_request(
|
||||
"GET", self.path, await_result=False
|
||||
)
|
||||
|
||||
# Advance past all the rate limiting logic. If we disconnect too early, the
|
||||
# request won't be processed.
|
||||
self.pump()
|
||||
|
||||
self._test_disconnect(
|
||||
self.reactor,
|
||||
channel,
|
||||
expect_cancellation=True,
|
||||
expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
|
||||
)
|
||||
|
||||
def test_uncancellable_disconnect(self) -> None:
|
||||
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
|
||||
channel = self.make_signed_federation_request(
|
||||
"POST",
|
||||
self.path,
|
||||
content={},
|
||||
await_result=False,
|
||||
)
|
||||
|
||||
# Advance past all the rate limiting logic. If we disconnect too early, the
|
||||
# request won't be processed.
|
||||
self.pump()
|
||||
|
||||
self._test_disconnect(
|
||||
self.reactor,
|
||||
channel,
|
||||
expect_cancellation=False,
|
||||
expected_body={"result": True},
|
||||
)
|
||||
13
tests/http/server/__init__.py
Normal file
13
tests/http/server/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
100
tests/http/server/_base.py
Normal file
100
tests/http/server/_base.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unles4s required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from unittest import mock
|
||||
|
||||
from twisted.internet.error import ConnectionDone
|
||||
|
||||
from synapse.http.server import (
|
||||
HTTP_STATUS_REQUEST_CANCELLED,
|
||||
respond_with_html_bytes,
|
||||
respond_with_json,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel, ThreadedMemoryReactorClock
|
||||
|
||||
|
||||
class EndpointCancellationTestHelperMixin(unittest.TestCase):
|
||||
"""Provides helper methods for testing cancellation of endpoints."""
|
||||
|
||||
def _test_disconnect(
|
||||
self,
|
||||
reactor: ThreadedMemoryReactorClock,
|
||||
channel: FakeChannel,
|
||||
expect_cancellation: bool,
|
||||
expected_body: Union[bytes, JsonDict],
|
||||
expected_code: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Disconnects an in-flight request and checks the response.
|
||||
|
||||
Args:
|
||||
reactor: The twisted reactor running the request handler.
|
||||
channel: The `FakeChannel` for the request.
|
||||
expect_cancellation: `True` if request processing is expected to be
|
||||
cancelled, `False` if the request should run to completion.
|
||||
expected_body: The expected response for the request.
|
||||
expected_code: The expected status code for the request. Defaults to `200`
|
||||
or `499` depending on `expect_cancellation`.
|
||||
"""
|
||||
# Determine the expected status code.
|
||||
if expected_code is None:
|
||||
if expect_cancellation:
|
||||
expected_code = HTTP_STATUS_REQUEST_CANCELLED
|
||||
else:
|
||||
expected_code = HTTPStatus.OK
|
||||
|
||||
request = channel.request
|
||||
self.assertFalse(
|
||||
channel.is_finished(),
|
||||
"Request finished before we could disconnect - "
|
||||
"was `await_result=False` passed to `make_request`?",
|
||||
)
|
||||
|
||||
# We're about to disconnect the request. This also disconnects the channel, so
|
||||
# we have to rely on mocks to extract the response.
|
||||
respond_method: Callable[..., Any]
|
||||
if isinstance(expected_body, bytes):
|
||||
respond_method = respond_with_html_bytes
|
||||
else:
|
||||
respond_method = respond_with_json
|
||||
|
||||
with mock.patch(
|
||||
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
|
||||
) as respond_mock:
|
||||
# Disconnect the request.
|
||||
request.connectionLost(reason=ConnectionDone())
|
||||
|
||||
if expect_cancellation:
|
||||
# An immediate cancellation is expected.
|
||||
respond_mock.assert_called_once()
|
||||
args, _kwargs = respond_mock.call_args
|
||||
code, body = args[1], args[2]
|
||||
self.assertEqual(code, expected_code)
|
||||
self.assertEqual(request.code, expected_code)
|
||||
self.assertEqual(body, expected_body)
|
||||
else:
|
||||
respond_mock.assert_not_called()
|
||||
|
||||
# The handler is expected to run to completion.
|
||||
reactor.pump([1.0])
|
||||
respond_mock.assert_called_once()
|
||||
args, _kwargs = respond_mock.call_args
|
||||
code, body = args[1], args[2]
|
||||
self.assertEqual(code, expected_code)
|
||||
self.assertEqual(request.code, expected_code)
|
||||
self.assertEqual(body, expected_body)
|
||||
@@ -12,16 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
from unittest.mock import Mock
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import cancellable
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_json_object_from_request,
|
||||
parse_json_value_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests import unittest
|
||||
from tests.http.server._base import EndpointCancellationTestHelperMixin
|
||||
|
||||
|
||||
def make_request(content):
|
||||
@@ -76,3 +85,52 @@ class TestServletUtils(unittest.TestCase):
|
||||
# Test not an object
|
||||
with self.assertRaises(SynapseError):
|
||||
parse_json_object_from_request(make_request(b'["foo"]'))
|
||||
|
||||
|
||||
class CancellableRestServlet(RestServlet):
|
||||
"""A `RestServlet` with a mix of cancellable and uncancellable handlers."""
|
||||
|
||||
PATTERNS = client_patterns("/sleep$")
|
||||
|
||||
def __init__(self, hs: HomeServer):
|
||||
super().__init__()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@cancellable
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, {"result": True}
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, {"result": True}
|
||||
|
||||
|
||||
class TestRestServletCancellation(
|
||||
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
|
||||
):
|
||||
"""Tests for `RestServlet` cancellation."""
|
||||
|
||||
servlets = [
|
||||
lambda hs, http_server: CancellableRestServlet(hs).register(http_server)
|
||||
]
|
||||
|
||||
def test_cancellable_disconnect(self) -> None:
|
||||
"""Test that handlers with the `@cancellable` flag can be cancelled."""
|
||||
channel = self.make_request("GET", "/sleep", await_result=False)
|
||||
self._test_disconnect(
|
||||
self.reactor,
|
||||
channel,
|
||||
expect_cancellation=True,
|
||||
expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
|
||||
)
|
||||
|
||||
def test_uncancellable_disconnect(self) -> None:
|
||||
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
|
||||
channel = self.make_request("POST", "/sleep", await_result=False)
|
||||
self._test_disconnect(
|
||||
self.reactor,
|
||||
channel,
|
||||
expect_cancellation=False,
|
||||
expected_body={"result": True},
|
||||
)
|
||||
|
||||
13
tests/replication/http/__init__.py
Normal file
13
tests/replication/http/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
106
tests/replication/http/test__base.py
Normal file
106
tests/replication/http/test__base.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.http.server import JsonResource, cancellable
|
||||
from synapse.replication.http import REPLICATION_PREFIX
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests import unittest
|
||||
from tests.http.server._base import EndpointCancellationTestHelperMixin
|
||||
|
||||
|
||||
class CancellableReplicationEndpoint(ReplicationEndpoint):
|
||||
NAME = "cancellable_sleep"
|
||||
PATH_ARGS = ()
|
||||
CACHE = False
|
||||
|
||||
def __init__(self, hs: HomeServer):
|
||||
super().__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload() -> JsonDict:
|
||||
return {}
|
||||
|
||||
@cancellable
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, {"result": True}
|
||||
|
||||
|
||||
class UncancellableReplicationEndpoint(ReplicationEndpoint):
|
||||
NAME = "uncancellable_sleep"
|
||||
PATH_ARGS = ()
|
||||
CACHE = False
|
||||
|
||||
def __init__(self, hs: HomeServer):
|
||||
super().__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload() -> JsonDict:
|
||||
return {}
|
||||
|
||||
async def _handle_request( # type: ignore[override]
|
||||
self, request: Request
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, {"result": True}
|
||||
|
||||
|
||||
class ReplicationEndpointCancellationTestCase(
|
||||
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
|
||||
):
|
||||
"""Tests for `ReplicationEndpoint` cancellation."""
|
||||
|
||||
def create_test_resource(self):
|
||||
"""Overrides `HomeserverTestCase.create_test_resource`."""
|
||||
resource = JsonResource(self.hs)
|
||||
|
||||
CancellableReplicationEndpoint(self.hs).register(resource)
|
||||
UncancellableReplicationEndpoint(self.hs).register(resource)
|
||||
|
||||
return resource
|
||||
|
||||
def test_cancellable_disconnect(self) -> None:
|
||||
"""Test that handlers with the `@cancellable` flag can be cancelled."""
|
||||
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
|
||||
channel = self.make_request("POST", path, await_result=False)
|
||||
self._test_disconnect(
|
||||
self.reactor,
|
||||
channel,
|
||||
expect_cancellation=True,
|
||||
expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
|
||||
)
|
||||
|
||||
def test_uncancellable_disconnect(self) -> None:
|
||||
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
|
||||
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
|
||||
channel = self.make_request("POST", path, await_result=False)
|
||||
self._test_disconnect(
|
||||
self.reactor,
|
||||
channel,
|
||||
expect_cancellation=False,
|
||||
expected_body={"result": True},
|
||||
)
|
||||
@@ -109,6 +109,17 @@ class FakeChannel:
|
||||
_ip: str = "127.0.0.1"
|
||||
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
|
||||
resource_usage: Optional[ContextResourceUsage] = None
|
||||
_request: Optional[Request] = None
|
||||
|
||||
@property
|
||||
def request(self) -> Request:
|
||||
assert self._request is not None
|
||||
return self._request
|
||||
|
||||
@request.setter
|
||||
def request(self, request: Request) -> None:
|
||||
assert self._request is None
|
||||
self._request = request
|
||||
|
||||
@property
|
||||
def json_body(self):
|
||||
@@ -322,6 +333,8 @@ def make_request(
|
||||
channel = FakeChannel(site, reactor, ip=client_ip)
|
||||
|
||||
req = request(channel, site)
|
||||
channel.request = req
|
||||
|
||||
req.content = BytesIO(content)
|
||||
# Twisted expects to be at the end of the content when parsing the request.
|
||||
req.content.seek(0, SEEK_END)
|
||||
|
||||
@@ -13,18 +13,28 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import Tuple
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.errors import Codes, RedirectException, SynapseError
|
||||
from synapse.config.server import parse_listener_def
|
||||
from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.http.server import (
|
||||
DirectServeHtmlResource,
|
||||
DirectServeJsonResource,
|
||||
JsonResource,
|
||||
OptionsResource,
|
||||
cancellable,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest, SynapseSite
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.http.server._base import EndpointCancellationTestHelperMixin
|
||||
from tests.server import (
|
||||
FakeSite,
|
||||
ThreadedMemoryReactorClock,
|
||||
@@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertNotIn("body", channel.result)
|
||||
|
||||
|
||||
class CancellableDirectServeJsonResource(DirectServeJsonResource):
|
||||
def __init__(self, clock: Clock):
|
||||
super().__init__()
|
||||
self.clock = clock
|
||||
|
||||
@cancellable
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, {"result": True}
|
||||
|
||||
async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, {"result": True}
|
||||
|
||||
|
||||
class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
|
||||
ERROR_TEMPLATE = "{code} {msg}"
|
||||
|
||||
def __init__(self, clock: Clock):
|
||||
super().__init__()
|
||||
self.clock = clock
|
||||
|
||||
@cancellable
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, b"ok"
|
||||
|
||||
async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]:
|
||||
await self.clock.sleep(1.0)
|
||||
return HTTPStatus.OK, b"ok"
|
||||
|
||||
|
||||
class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
|
||||
"""Tests for `DirectServeJsonResource` cancellation."""
|
||||
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
self.clock = Clock(self.reactor)
|
||||
self.resource = CancellableDirectServeJsonResource(self.clock)
|
||||
self.site = FakeSite(self.resource, self.reactor)
|
||||
|
||||
def test_cancellable_disconnect(self) -> None:
|
||||
"""Test that handlers with the `@cancellable` flag can be cancelled."""
|
||||
channel = make_request(
|
||||
self.reactor, self.site, "GET", "/sleep", await_result=False
|
||||
)
|
||||
self._test_disconnect(
|
||||
self.reactor,
|
||||
channel,
|
||||
expect_cancellation=True,
|
||||
expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
|
||||
)
|
||||
|
||||
def test_uncancellable_disconnect(self) -> None:
|
||||
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
|
||||
channel = make_request(
|
||||
self.reactor, self.site, "POST", "/sleep", await_result=False
|
||||
)
|
||||
self._test_disconnect(
|
||||
self.reactor,
|
||||
channel,
|
||||
expect_cancellation=False,
|
||||
expected_body={"result": True},
|
||||
)
|
||||
|
||||
|
||||
class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
|
||||
"""Tests for `DirectServeHtmlResource` cancellation."""
|
||||
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
self.clock = Clock(self.reactor)
|
||||
self.resource = CancellableDirectServeHtmlResource(self.clock)
|
||||
self.site = FakeSite(self.resource, self.reactor)
|
||||
|
||||
def test_cancellable_disconnect(self) -> None:
|
||||
"""Test that handlers with the `@cancellable` flag can be cancelled."""
|
||||
channel = make_request(
|
||||
self.reactor, self.site, "GET", "/sleep", await_result=False
|
||||
)
|
||||
self._test_disconnect(
|
||||
self.reactor,
|
||||
channel,
|
||||
expect_cancellation=True,
|
||||
expected_body=b"499 Request cancelled",
|
||||
)
|
||||
|
||||
def test_uncancellable_disconnect(self) -> None:
|
||||
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
|
||||
channel = make_request(
|
||||
self.reactor, self.site, "POST", "/sleep", await_result=False
|
||||
)
|
||||
self._test_disconnect(
|
||||
self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
|
||||
)
|
||||
|
||||
@@ -831,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
||||
self.site,
|
||||
method=method,
|
||||
path=path,
|
||||
content=content or "",
|
||||
content=content if content is not None else "",
|
||||
shorthand=False,
|
||||
await_result=await_result,
|
||||
custom_headers=custom_headers,
|
||||
|
||||
Reference in New Issue
Block a user