Ratelimiting of remote media downloads (#17256)
This commit is contained in:
@@ -25,7 +25,7 @@ import tempfile
|
||||
from binascii import unhexlify
|
||||
from io import BytesIO
|
||||
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from urllib import parse
|
||||
|
||||
import attr
|
||||
@@ -37,9 +37,12 @@ from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.errors import Codes, HttpResponseException
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.events import EventBase
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
@@ -59,6 +62,7 @@ from synapse.util import Clock
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel
|
||||
from tests.test_utils import SMALL_PNG
|
||||
from tests.unittest import override_config
|
||||
from tests.utils import default_config
|
||||
|
||||
|
||||
@@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
destination: str,
|
||||
path: str,
|
||||
output_stream: BinaryIO,
|
||||
download_ratelimiter: Ratelimiter,
|
||||
ip_address: Any,
|
||||
max_size: int,
|
||||
args: Optional[QueryParams] = None,
|
||||
retry_on_dns_fail: bool = True,
|
||||
max_size: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
follow_redirects: bool = False,
|
||||
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
|
||||
@@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||
tok=self.tok,
|
||||
expect_code=400,
|
||||
)
|
||||
|
||||
|
||||
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
self.storage_path = self.mktemp()
|
||||
self.media_store_path = self.mktemp()
|
||||
os.mkdir(self.storage_path)
|
||||
os.mkdir(self.media_store_path)
|
||||
config["media_store_path"] = self.media_store_path
|
||||
|
||||
provider_config = {
|
||||
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
|
||||
"store_local": True,
|
||||
"store_synchronous": False,
|
||||
"store_remote": True,
|
||||
"config": {"directory": self.storage_path},
|
||||
}
|
||||
|
||||
config["media_storage_providers"] = [provider_config]
|
||||
|
||||
return self.setup_test_homeserver(config=config)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.repo = hs.get_media_repository()
|
||||
self.client = hs.get_federation_http_client()
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
# We need to manually set the resource tree to include media, the
|
||||
# default only does `/_matrix/client` APIs.
|
||||
return {"/_matrix/media": self.hs.get_media_repository_resource()}
|
||||
|
||||
# mock actually reading file body
|
||||
def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred:
|
||||
d: Deferred = defer.Deferred()
|
||||
d.callback(31457280)
|
||||
return d
|
||||
|
||||
def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred:
|
||||
d: Deferred = defer.Deferred()
|
||||
d.callback(52428800)
|
||||
return d
|
||||
|
||||
@patch(
|
||||
"synapse.http.matrixfederationclient.read_body_with_max_size",
|
||||
read_body_with_max_size_30MiB,
|
||||
)
|
||||
def test_download_ratelimit_default(self) -> None:
|
||||
"""
|
||||
Test remote media download ratelimiting against default configuration - 500MB bucket
|
||||
and 87kb/second drain rate
|
||||
"""
|
||||
|
||||
# mock out actually sending the request, returns a 30MiB response
|
||||
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||
resp = MagicMock(spec=IResponse)
|
||||
resp.code = 200
|
||||
resp.length = 31457280
|
||||
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
|
||||
resp.phrase = b"OK"
|
||||
return resp
|
||||
|
||||
self.client._send_request = _send_request # type: ignore
|
||||
|
||||
# first request should go through
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel.code == 200
|
||||
|
||||
# next 15 should go through
|
||||
for i in range(15):
|
||||
channel2 = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel2.code == 200
|
||||
|
||||
# 17th will hit ratelimit
|
||||
channel3 = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel3.code == 429
|
||||
|
||||
# however, a request from a different IP will go through
|
||||
channel4 = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
|
||||
shorthand=False,
|
||||
client_ip="187.233.230.159",
|
||||
)
|
||||
assert channel4.code == 200
|
||||
|
||||
# at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
|
||||
# 30MiB download is authorized - The last download was blocked at 503,316,480.
|
||||
# The next download will be authorized when bucket hits 492,830,720
|
||||
# (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
|
||||
# needs to drain before another download will be authorized, that will take ~=
|
||||
# 2 minutes (10,485,760/89,088/60)
|
||||
self.reactor.pump([2.0 * 60.0])
|
||||
|
||||
# enough has drained and next request goes through
|
||||
channel5 = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel5.code == 200
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"remote_media_download_per_second": "50M",
|
||||
"remote_media_download_burst_count": "50M",
|
||||
}
|
||||
)
|
||||
@patch(
|
||||
"synapse.http.matrixfederationclient.read_body_with_max_size",
|
||||
read_body_with_max_size_50MiB,
|
||||
)
|
||||
def test_download_rate_limit_config(self) -> None:
|
||||
"""
|
||||
Test that download rate limit config options are correctly picked up and applied
|
||||
"""
|
||||
|
||||
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||
resp = MagicMock(spec=IResponse)
|
||||
resp.code = 200
|
||||
resp.length = 52428800
|
||||
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
|
||||
resp.phrase = b"OK"
|
||||
return resp
|
||||
|
||||
self.client._send_request = _send_request # type: ignore
|
||||
|
||||
# first request should go through
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel.code == 200
|
||||
|
||||
# immediate second request should fail
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel.code == 429
|
||||
|
||||
# advance half a second
|
||||
self.reactor.pump([0.5])
|
||||
|
||||
# request still fails
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel.code == 429
|
||||
|
||||
# advance another half second
|
||||
self.reactor.pump([0.5])
|
||||
|
||||
# enough has drained from bucket and request is successful
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel.code == 200
|
||||
|
||||
@patch(
|
||||
"synapse.http.matrixfederationclient.read_body_with_max_size",
|
||||
read_body_with_max_size_30MiB,
|
||||
)
|
||||
def test_download_ratelimit_max_size_sub(self) -> None:
|
||||
"""
|
||||
Test that if no content-length is provided, the default max size is applied instead
|
||||
"""
|
||||
|
||||
# mock out actually sending the request
|
||||
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||
resp = MagicMock(spec=IResponse)
|
||||
resp.code = 200
|
||||
resp.length = UNKNOWN_LENGTH
|
||||
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
|
||||
resp.phrase = b"OK"
|
||||
return resp
|
||||
|
||||
self.client._send_request = _send_request # type: ignore
|
||||
|
||||
# ten requests should go through using the max size (500MB/50MB)
|
||||
for i in range(10):
|
||||
channel2 = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel2.code == 200
|
||||
|
||||
# eleventh will hit ratelimit
|
||||
channel3 = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
|
||||
shorthand=False,
|
||||
)
|
||||
assert channel3.code == 429
|
||||
|
||||
Reference in New Issue
Block a user