From eb13e9ead604d54d2e4c0cf55522232ddb1e9bc1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Jun 2025 10:41:58 +0100 Subject: [PATCH] Add ability to limit amount uploaded by a user You can now configure how much media can be uploaded by a user in a given time period. --- synapse/config/repository.py | 16 +++ synapse/media/media_repository.py | 36 ++++++ .../databases/main/media_repository.py | 36 ++++++ tests/rest/client/test_media.py | 121 ++++++++++++++++++ 4 files changed, 209 insertions(+) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index fc5a90c85a..e6a5064c16 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -119,6 +119,15 @@ def parse_thumbnail_requirements( } +@attr.s(auto_attribs=True, slots=True, frozen=True) +class MediaUploadLimit: + """A limit on the amount of data a user can upload in a given time + period.""" + + max_bytes: int + time_period_ms: int + + class ContentRepositoryConfig(Config): section = "media" @@ -274,6 +283,13 @@ class ContentRepositoryConfig(Config): self.enable_authenticated_media = config.get("enable_authenticated_media", True) + self.media_upload_limits: List[MediaUploadLimit] = [] + for limit_config in config.get("media_upload_limits", []): + time_period_ms = self.parse_duration(limit_config["time_period"]) + max_bytes = self.parse_size(limit_config["max_size"]) + + self.media_upload_limits.append(MediaUploadLimit(max_bytes, time_period_ms)) + def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str: assert data_dir_path is not None media_store = os.path.join(data_dir_path, "media_store") diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 2a9d6ec11e..8b8af05061 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -177,6 +177,13 @@ class MediaRepository: else: self.url_previewer = None + # We get the media upload limits and sort them in descending order of + # time 
period, so that we can apply some optimizations. + self.media_upload_limits = hs.config.media.media_upload_limits + self.media_upload_limits.sort( + key=lambda limit: limit.time_period_ms, reverse=True + ) + def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( "update_recently_accessed_media", self._update_recently_accessed @@ -327,6 +334,35 @@ class MediaRepository: "Media has been automatically quarantined as it matched existing quarantined media" ) + # Check that the user has not exceeded any of the media upload limits. + + # This is the total size of media uploaded by the user in the last + # `time_period_ms` milliseconds, or None if we haven't checked yet. + uploaded_media_size: Optional[int] = None + + # Note: the media upload limits are sorted so larger time periods are + # first. + for limit in self.media_upload_limits: + # We only need to check the amount of media uploaded by the user in + # this latest (smaller) time period if the amount of media uploaded + # in a previous (larger) time period is above the limit. + # + # This optimization means that in the common case where the user + # hasn't uploaded much media, we only need to query the database + # once. 
+ if ( + uploaded_media_size is None + or uploaded_media_size + content_length > limit.max_bytes + ): + uploaded_media_size = await self.store.get_media_uploaded_size_for_user( + user_id=auth_user.to_string(), time_period_ms=limit.time_period_ms + ) + + if uploaded_media_size + content_length > limit.max_bytes: + raise SynapseError( + 400, "Media upload limit exceeded", Codes.RESOURCE_LIMIT_EXCEEDED + ) + if is_new_media: await self.store.store_local_media( media_id=media_id, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 04866524e3..f726846e57 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -1034,3 +1034,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "local_media_repository", sha256, ) + + async def get_media_uploaded_size_for_user( + self, user_id: str, time_period_ms: int + ) -> int: + """Get the total size of media uploaded by a user in the last + time_period_ms milliseconds. + + Args: + user_id: The user ID to check. + time_period_ms: The time period in milliseconds to consider. + + Returns: + The total size of media uploaded by the user in bytes. + """ + + sql = """ + SELECT COALESCE(SUM(media_length), 0) + FROM local_media_repository + WHERE user_id = ? AND created_ts > ? 
+ """ + + def _get_media_uploaded_size_for_user_txn( + txn: LoggingTransaction, + ) -> int: + # Calculate the timestamp for the start of the time period + start_ts = self._clock.time_msec() - time_period_ms + txn.execute(sql, (user_id, start_ts)) + row = txn.fetchone() + if row is None: + return 0 + return row[0] + + return await self.db_pool.runInteraction( + "get_media_uploaded_size_for_user", + _get_media_uploaded_size_for_user_txn, + ) diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 83aa6a280b..7aa1f2406c 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -2846,3 +2846,124 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase): custom_headers=[("If-None-Match", etag)], ) self.assertEqual(channel3.code, 404) + + +class MediaUploadLimits(unittest.HomeserverTestCase): + """ + This test case simulates a homeserver with media upload limits configured. + """ + + servlets = [ + media.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = [provider_config] + + # These are the limits that we are testing + config["media_upload_limits"] = [ + {"time_period": "1d", "max_size": "1K"}, + {"time_period": "1w", "max_size": "3K"}, + ] + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.repo = hs.get_media_repository() + 
self.client = hs.get_federation_http_client() + self.store = hs.get_datastores().main + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + def upload_media(self, size: int) -> FakeChannel: + """Helper to upload media of a given size.""" + return self.make_request( + "POST", + "/_matrix/media/v3/upload", + content=b"0" * size, + access_token=self.tok, + shorthand=False, + content_type=b"text/plain", + custom_headers=[("Content-Length", str(size))], + ) + + def test_upload_under_limit(self) -> None: + """Test that uploading media under the limit works.""" + channel = self.upload_media(67) + self.assertEqual(channel.code, 200) + + def test_over_day_limit(self) -> None: + """Test that uploading media over the daily limit fails.""" + channel = self.upload_media(500) + self.assertEqual(channel.code, 200) + + channel = self.upload_media(800) + self.assertEqual(channel.code, 400) + + def test_under_daily_limit(self) -> None: + """Test that uploading media under the daily limit succeeds.""" + channel = self.upload_media(500) + self.assertEqual(channel.code, 200) + + self.reactor.advance(60 * 60 * 24) # Advance by one day + + # This will succeed as the daily limit has reset + channel = self.upload_media(800) + self.assertEqual(channel.code, 200) + + self.reactor.advance(60 * 60 * 24) # Advance by one day + + # ... 
and again + channel = self.upload_media(800) + self.assertEqual(channel.code, 200) + + def test_over_weekly_limit(self) -> None: + """Test that uploading media over the weekly limit fails.""" + channel = self.upload_media(900) + self.assertEqual(channel.code, 200) + + self.reactor.advance(60 * 60 * 24) # Advance by one day + + channel = self.upload_media(900) + self.assertEqual(channel.code, 200) + + self.reactor.advance(2 * 60 * 60 * 24) # Advance by two days + + channel = self.upload_media(900) + self.assertEqual(channel.code, 200) + + self.reactor.advance(2 * 60 * 60 * 24) # Advance by two days + + # This will fail as the weekly limit has been exceeded + channel = self.upload_media(900) + self.assertEqual(channel.code, 400) + + # Reset the weekly limit by advancing a week + self.reactor.advance(7 * 60 * 60 * 24) # Advance by 7 days + + # This will succeed as the weekly limit has reset + channel = self.upload_media(900) + self.assertEqual(channel.code, 200)