diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 07827cf95b..10e3d37d4e 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -824,7 +824,7 @@ class SsoHandler: return True # store it in media repository - avatar_mxc_url = await self._media_repo.create_content( + avatar_mxc_url = await self._media_repo.create_or_update_content( media_type=headers[b"Content-Type"][0].decode("utf-8"), upload_name=upload_name, content=picture, diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 18c5a8ecec..2a9d6ec11e 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -285,63 +285,16 @@ class MediaRepository: raise NotFoundError("Media ID has expired") @trace - async def update_content( - self, - media_id: str, - media_type: str, - upload_name: Optional[str], - content: IO, - content_length: int, - auth_user: UserID, - ) -> None: - """Update the content of the given media ID. - - Args: - media_id: The media ID to replace. - media_type: The content type of the file. - upload_name: The name of the file, if provided. - content: A file like object that is the content to store - content_length: The length of the content - auth_user: The user_id of the uploader - """ - file_info = FileInfo(server_name=None, file_id=media_id) - sha256reader = SHA256TransparentIOReader(content) - # This implements all of IO as it has a passthrough - fname = await self.media_storage.store_file(sha256reader.wrap(), file_info) - sha256 = sha256reader.hexdigest() - should_quarantine = await self.store.get_is_hash_quarantined(sha256) - logger.info("Stored local media in file %r", fname) - - if should_quarantine: - logger.warn( - "Media has been automatically quarantined as it matched existing quarantined media" - ) - - await self.store.update_local_media( - media_id=media_id, - media_type=media_type, - upload_name=upload_name, - media_length=content_length, - user_id=auth_user, - sha256=sha256, - quarantined_by="system" if should_quarantine else None, - ) - - try: - await self._generate_thumbnails(None, media_id, media_id, media_type) - except Exception as e: - logger.info("Failed to generate thumbnails: %s", e) - - @trace - async def create_content( + async def create_or_update_content( self, media_type: str, upload_name: Optional[str], content: IO, content_length: int, auth_user: UserID, + media_id: Optional[str] = None, ) -> MXCUri: - """Store uploaded content for a local user and return the mxc URL + """Create or update the content of the given media ID. Args: media_type: The content type of the file. @@ -349,16 +302,20 @@ class MediaRepository: content: A file like object that is the content to store content_length: The length of the content auth_user: The user_id of the uploader + media_id: The media ID to update if provided, otherwise creates + new media ID. Returns: The mxc url of the stored content """ - media_id = random_string(24) + is_new_media = media_id is None + if media_id is None: + media_id = random_string(24) file_info = FileInfo(server_name=None, file_id=media_id) - # This implements all of IO as it has a passthrough sha256reader = SHA256TransparentIOReader(content) + # This implements all of IO as it has a passthrough fname = await self.media_storage.store_file(sha256reader.wrap(), file_info) sha256 = sha256reader.hexdigest() should_quarantine = await self.store.get_is_hash_quarantined(sha256) @@ -370,16 +327,27 @@ class MediaRepository: "Media has been automatically quarantined as it matched existing quarantined media" ) - await self.store.store_local_media( - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=content_length, - user_id=auth_user, - sha256=sha256, - quarantined_by="system" if should_quarantine else None, - ) + if is_new_media: + await self.store.store_local_media( + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + sha256=sha256, + quarantined_by="system" if should_quarantine else None, + ) + else: + await self.store.update_local_media( + media_id=media_id, + media_type=media_type, + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + sha256=sha256, + quarantined_by="system" if should_quarantine else None, + ) try: await self._generate_thumbnails(None, media_id, media_id, media_type) diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py index 572f7897fd..74d8280582 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -120,7 +120,7 @@ class UploadServlet(BaseUploadServlet): try: content: IO = request.content # type: ignore - content_uri = await self.media_repo.create_content( + content_uri = await self.media_repo.create_or_update_content( media_type, upload_name, content, content_length, requester.user ) except SpamMediaException: @@ -170,13 +170,13 @@ class AsyncUploadServlet(BaseUploadServlet): try: content: IO = request.content # type: ignore - await self.media_repo.update_content( - media_id, + await self.media_repo.create_or_update_content( media_type, upload_name, content, content_length, requester.user, + media_id=media_id, ) except SpamMediaException: # For uploading of media we want to respond with a 400, instead of diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index 9c92003ce5..cd4905239f 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -67,7 +67,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): def test_file_download(self) -> None: content = io.BytesIO(b"file_to_stream") content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "text/plain", "test_upload", content, @@ -110,7 +110,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): content = io.BytesIO(SMALL_PNG) content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "image/png", "test_png_upload", content, @@ -152,7 +152,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): content = io.BytesIO(b"file_to_stream") content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "text/plain", "test_upload", content, @@ -215,7 +215,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): def test_thumbnail_download_scaled(self) -> None: content = io.BytesIO(small_png.data) content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "image/png", "test_png_thumbnail", content, @@ -255,7 +255,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): def test_thumbnail_download_cropped(self) -> None: content = io.BytesIO(small_png.data) content_uri = self.get_success( - self.media_repo.create_content( + self.media_repo.create_or_update_content( "image/png", "test_png_thumbnail", content, diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py index d8f4f57c8c..89cf61430a 100644 --- a/tests/media/test_media_retention.py +++ b/tests/media/test_media_retention.py @@ -78,7 +78,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): # If the meda random_content = bytes(random_string(24), "utf-8") mxc_uri: MXCUri = self.get_success( - media_repository.create_content( + media_repository.create_or_update_content( media_type="text/plain", upload_name=None, content=io.BytesIO(random_content), diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 6ee761e44b..83aa6a280b 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -1952,7 +1952,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): def test_file_download(self) -> None: content = io.BytesIO(b"file_to_stream") content_uri = self.get_success( - self.repo.create_content( + self.repo.create_or_update_content( "text/plain", "test_upload", content,