Add ability to limit amount uploaded by a user
You can now configure how much media can be uploaded by a user in a given time period.
This commit is contained in:
@@ -119,6 +119,15 @@ def parse_thumbnail_requirements(
|
||||
}
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True, frozen=True)
|
||||
class MediaUploadLimit:
|
||||
"""A limit on the amount of data a user can upload in a given time
|
||||
period."""
|
||||
|
||||
max_bytes: int
|
||||
time_period_ms: int
|
||||
|
||||
|
||||
class ContentRepositoryConfig(Config):
|
||||
section = "media"
|
||||
|
||||
@@ -274,6 +283,13 @@ class ContentRepositoryConfig(Config):
|
||||
|
||||
self.enable_authenticated_media = config.get("enable_authenticated_media", True)
|
||||
|
||||
self.media_upload_limits: List[MediaUploadLimit] = []
|
||||
for limit_config in config.get("media_upload_limits", []):
|
||||
time_period_ms = self.parse_duration(limit_config["time_period"])
|
||||
max_bytes = self.parse_size(limit_config["max_size"])
|
||||
|
||||
self.media_upload_limits.append(MediaUploadLimit(max_bytes, time_period_ms))
|
||||
|
||||
def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
|
||||
assert data_dir_path is not None
|
||||
media_store = os.path.join(data_dir_path, "media_store")
|
||||
|
||||
@@ -177,6 +177,13 @@ class MediaRepository:
|
||||
else:
|
||||
self.url_previewer = None
|
||||
|
||||
# We get the media upload limits and sort them in descending order of
|
||||
# time period, so that we can apply some optimizations.
|
||||
self.media_upload_limits = hs.config.media.media_upload_limits
|
||||
self.media_upload_limits.sort(
|
||||
key=lambda limit: limit.time_period_ms, reverse=True
|
||||
)
|
||||
|
||||
def _start_update_recently_accessed(self) -> Deferred:
|
||||
return run_as_background_process(
|
||||
"update_recently_accessed_media", self._update_recently_accessed
|
||||
@@ -327,6 +334,35 @@ class MediaRepository:
|
||||
"Media has been automatically quarantined as it matched existing quarantined media"
|
||||
)
|
||||
|
||||
# Check that the user has not exceeded any of the media upload limits.
|
||||
|
||||
# This is the total size of media uploaded by the user in the last
|
||||
# `time_period_ms` milliseconds, or None if we haven't checked yet.
|
||||
uploaded_media_size: Optional[int] = None
|
||||
|
||||
# Note: the media upload limits are sorted so larger time periods are
|
||||
# first.
|
||||
for limit in self.media_upload_limits:
|
||||
# We only need to check the amount of media uploaded by the user in
|
||||
# this latest (smaller) time period if the amount of media uploaded
|
||||
# in a previous (larger) time period is above the limit.
|
||||
#
|
||||
# This optimization means that in the common case where the user
|
||||
# hasn't uploaded much media, we only need to query the database
|
||||
# once.
|
||||
if (
|
||||
uploaded_media_size is None
|
||||
or uploaded_media_size + content_length > limit.max_bytes
|
||||
):
|
||||
uploaded_media_size = await self.store.get_media_uploaded_size_for_user(
|
||||
user_id=auth_user.to_string(), time_period_ms=limit.time_period_ms
|
||||
)
|
||||
|
||||
if uploaded_media_size + content_length > limit.max_bytes:
|
||||
raise SynapseError(
|
||||
400, "Media upload limit exceeded", Codes.RESOURCE_LIMIT_EXCEEDED
|
||||
)
|
||||
|
||||
if is_new_media:
|
||||
await self.store.store_local_media(
|
||||
media_id=media_id,
|
||||
|
||||
@@ -1034,3 +1034,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
"local_media_repository",
|
||||
sha256,
|
||||
)
|
||||
|
||||
async def get_media_uploaded_size_for_user(
|
||||
self, user_id: str, time_period_ms: int
|
||||
) -> int:
|
||||
"""Get the total size of media uploaded by a user in the last
|
||||
time_period_ms milliseconds.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check.
|
||||
time_period_ms: The time period in milliseconds to consider.
|
||||
|
||||
Returns:
|
||||
The total size of media uploaded by the user in bytes.
|
||||
"""
|
||||
|
||||
sql = """
|
||||
SELECT COALESCE(SUM(media_length), 0)
|
||||
FROM local_media_repository
|
||||
WHERE user_id = ? AND created_ts > ?
|
||||
"""
|
||||
|
||||
def _get_media_uploaded_size_for_user_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> int:
|
||||
# Calculate the timestamp for the start of the time period
|
||||
start_ts = self._clock.time_msec() - time_period_ms
|
||||
txn.execute(sql, (user_id, start_ts))
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
return 0
|
||||
return row[0]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_media_uploaded_size_for_user",
|
||||
_get_media_uploaded_size_for_user_txn,
|
||||
)
|
||||
|
||||
@@ -2846,3 +2846,124 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
|
||||
custom_headers=[("If-None-Match", etag)],
|
||||
)
|
||||
self.assertEqual(channel3.code, 404)
|
||||
|
||||
|
||||
class MediaUploadLimits(unittest.HomeserverTestCase):
|
||||
"""
|
||||
This test case simulates a homeserver with media upload limits configured.
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
media.register_servlets,
|
||||
login.register_servlets,
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
self.storage_path = self.mktemp()
|
||||
self.media_store_path = self.mktemp()
|
||||
os.mkdir(self.storage_path)
|
||||
os.mkdir(self.media_store_path)
|
||||
config["media_store_path"] = self.media_store_path
|
||||
|
||||
provider_config = {
|
||||
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
|
||||
"store_local": True,
|
||||
"store_synchronous": False,
|
||||
"store_remote": True,
|
||||
"config": {"directory": self.storage_path},
|
||||
}
|
||||
|
||||
config["media_storage_providers"] = [provider_config]
|
||||
|
||||
# These are the limits that we are testing
|
||||
config["media_upload_limits"] = [
|
||||
{"time_period": "1d", "max_size": "1K"},
|
||||
{"time_period": "1w", "max_size": "3K"},
|
||||
]
|
||||
|
||||
return self.setup_test_homeserver(config=config)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.repo = hs.get_media_repository()
|
||||
self.client = hs.get_federation_http_client()
|
||||
self.store = hs.get_datastores().main
|
||||
self.user = self.register_user("user", "pass")
|
||||
self.tok = self.login("user", "pass")
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
resources = super().create_resource_dict()
|
||||
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
|
||||
return resources
|
||||
|
||||
def upload_media(self, size: int) -> FakeChannel:
|
||||
"""Helper to upload media of a given size."""
|
||||
return self.make_request(
|
||||
"POST",
|
||||
"/_matrix/media/v3/upload",
|
||||
content=b"0" * size,
|
||||
access_token=self.tok,
|
||||
shorthand=False,
|
||||
content_type=b"text/plain",
|
||||
custom_headers=[("Content-Length", str(size))],
|
||||
)
|
||||
|
||||
def test_upload_under_limit(self) -> None:
|
||||
"""Test that uploading media under the limit works."""
|
||||
channel = self.upload_media(67)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
def test_over_day_limit(self) -> None:
|
||||
"""Test that uploading media over the daily limit fails."""
|
||||
channel = self.upload_media(500)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
channel = self.upload_media(800)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
def test_under_daily_limit(self) -> None:
|
||||
"""Test that uploading media under the daily limit fails."""
|
||||
channel = self.upload_media(500)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(60 * 60 * 24) # Advance by one day
|
||||
|
||||
# This will succeed as the daily limit has reset
|
||||
channel = self.upload_media(800)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(60 * 60 * 24) # Advance by one day
|
||||
|
||||
# ... and again
|
||||
channel = self.upload_media(800)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
def test_over_weekly_limit(self) -> None:
|
||||
"""Test that uploading media over the weekly limit fails."""
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(60 * 60 * 24) # Advance by one day
|
||||
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(2 * 60 * 60 * 24) # Advance by one day
|
||||
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(2 * 60 * 60 * 24) # Advance by one day
|
||||
|
||||
# This will fail as the weekly limit has been exceeded
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
# Reset the weekly limit by advancing a week
|
||||
self.reactor.advance(7 * 60 * 60 * 24) # Advance by 7 days
|
||||
|
||||
# This will succeed as the weekly limit has reset
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
Reference in New Issue
Block a user