1
0

Add ability to limit amount uploaded by a user

You can now configure how much media can be uploaded by a user in a
given time period.
This commit is contained in:
Erik Johnston
2025-06-09 10:41:58 +01:00
parent 447910df19
commit eb13e9ead6
4 changed files with 209 additions and 0 deletions

View File

@@ -119,6 +119,15 @@ def parse_thumbnail_requirements(
}
@attr.s(auto_attribs=True, slots=True, frozen=True)
class MediaUploadLimit:
"""A limit on the amount of data a user can upload in a given time
period."""
max_bytes: int
time_period_ms: int
class ContentRepositoryConfig(Config):
section = "media"
@@ -274,6 +283,13 @@ class ContentRepositoryConfig(Config):
self.enable_authenticated_media = config.get("enable_authenticated_media", True)
self.media_upload_limits: List[MediaUploadLimit] = []
for limit_config in config.get("media_upload_limits", []):
time_period_ms = self.parse_duration(limit_config["time_period"])
max_bytes = self.parse_size(limit_config["max_size"])
self.media_upload_limits.append(MediaUploadLimit(max_bytes, time_period_ms))
def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
assert data_dir_path is not None
media_store = os.path.join(data_dir_path, "media_store")

View File

@@ -177,6 +177,13 @@ class MediaRepository:
else:
self.url_previewer = None
# We get the media upload limits and sort them in descending order of
# time period, so that we can apply some optimizations.
self.media_upload_limits = hs.config.media.media_upload_limits
self.media_upload_limits.sort(
key=lambda limit: limit.time_period_ms, reverse=True
)
def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed
@@ -327,6 +334,35 @@ class MediaRepository:
"Media has been automatically quarantined as it matched existing quarantined media"
)
# Check that the user has not exceeded any of the media upload limits.
# This is the total size of media uploaded by the user in the last
# `time_period_ms` milliseconds, or None if we haven't checked yet.
uploaded_media_size: Optional[int] = None
# Note: the media upload limits are sorted so larger time periods are
# first.
for limit in self.media_upload_limits:
# We only need to check the amount of media uploaded by the user in
# this latest (smaller) time period if the amount of media uploaded
# in a previous (larger) time period is above the limit.
#
# This optimization means that in the common case where the user
# hasn't uploaded much media, we only need to query the database
# once.
if (
uploaded_media_size is None
or uploaded_media_size + content_length > limit.max_bytes
):
uploaded_media_size = await self.store.get_media_uploaded_size_for_user(
user_id=auth_user.to_string(), time_period_ms=limit.time_period_ms
)
if uploaded_media_size + content_length > limit.max_bytes:
raise SynapseError(
400, "Media upload limit exceeded", Codes.RESOURCE_LIMIT_EXCEEDED
)
if is_new_media:
await self.store.store_local_media(
media_id=media_id,

View File

@@ -1034,3 +1034,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"local_media_repository",
sha256,
)
async def get_media_uploaded_size_for_user(
self, user_id: str, time_period_ms: int
) -> int:
"""Get the total size of media uploaded by a user in the last
time_period_ms milliseconds.
Args:
user_id: The user ID to check.
time_period_ms: The time period in milliseconds to consider.
Returns:
The total size of media uploaded by the user in bytes.
"""
sql = """
SELECT COALESCE(SUM(media_length), 0)
FROM local_media_repository
WHERE user_id = ? AND created_ts > ?
"""
def _get_media_uploaded_size_for_user_txn(
txn: LoggingTransaction,
) -> int:
# Calculate the timestamp for the start of the time period
start_ts = self._clock.time_msec() - time_period_ms
txn.execute(sql, (user_id, start_ts))
row = txn.fetchone()
if row is None:
return 0
return row[0]
return await self.db_pool.runInteraction(
"get_media_uploaded_size_for_user",
_get_media_uploaded_size_for_user_txn,
)

View File

@@ -2846,3 +2846,124 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
custom_headers=[("If-None-Match", etag)],
)
self.assertEqual(channel3.code, 404)
class MediaUploadLimits(unittest.HomeserverTestCase):
"""
This test case simulates a homeserver with media upload limits configured.
"""
servlets = [
media.register_servlets,
login.register_servlets,
admin.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config["media_store_path"] = self.media_store_path
provider_config = {
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
# These are the limits that we are testing
config["media_upload_limits"] = [
{"time_period": "1d", "max_size": "1K"},
{"time_period": "1w", "max_size": "3K"},
]
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.repo = hs.get_media_repository()
self.client = hs.get_federation_http_client()
self.store = hs.get_datastores().main
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
def upload_media(self, size: int) -> FakeChannel:
"""Helper to upload media of a given size."""
return self.make_request(
"POST",
"/_matrix/media/v3/upload",
content=b"0" * size,
access_token=self.tok,
shorthand=False,
content_type=b"text/plain",
custom_headers=[("Content-Length", str(size))],
)
def test_upload_under_limit(self) -> None:
"""Test that uploading media under the limit works."""
channel = self.upload_media(67)
self.assertEqual(channel.code, 200)
def test_over_day_limit(self) -> None:
"""Test that uploading media over the daily limit fails."""
channel = self.upload_media(500)
self.assertEqual(channel.code, 200)
channel = self.upload_media(800)
self.assertEqual(channel.code, 400)
def test_under_daily_limit(self) -> None:
"""Test that uploading media under the daily limit fails."""
channel = self.upload_media(500)
self.assertEqual(channel.code, 200)
self.reactor.advance(60 * 60 * 24) # Advance by one day
# This will succeed as the daily limit has reset
channel = self.upload_media(800)
self.assertEqual(channel.code, 200)
self.reactor.advance(60 * 60 * 24) # Advance by one day
# ... and again
channel = self.upload_media(800)
self.assertEqual(channel.code, 200)
def test_over_weekly_limit(self) -> None:
"""Test that uploading media over the weekly limit fails."""
channel = self.upload_media(900)
self.assertEqual(channel.code, 200)
self.reactor.advance(60 * 60 * 24) # Advance by one day
channel = self.upload_media(900)
self.assertEqual(channel.code, 200)
self.reactor.advance(2 * 60 * 60 * 24) # Advance by one day
channel = self.upload_media(900)
self.assertEqual(channel.code, 200)
self.reactor.advance(2 * 60 * 60 * 24) # Advance by one day
# This will fail as the weekly limit has been exceeded
channel = self.upload_media(900)
self.assertEqual(channel.code, 400)
# Reset the weekly limit by advancing a week
self.reactor.advance(7 * 60 * 60 * 24) # Advance by 7 days
# This will succeed as the weekly limit has reset
channel = self.upload_media(900)
self.assertEqual(channel.code, 200)