1
0

Merge create_content and update_content

This deduplicates a bunch of logic.
This commit is contained in:
Erik Johnston
2025-06-09 10:24:06 +01:00
parent 6521406a37
commit 447910df19
6 changed files with 41 additions and 73 deletions

View File

@@ -824,7 +824,7 @@ class SsoHandler:
return True
# store it in media repository
avatar_mxc_url = await self._media_repo.create_content(
avatar_mxc_url = await self._media_repo.create_or_update_content(
media_type=headers[b"Content-Type"][0].decode("utf-8"),
upload_name=upload_name,
content=picture,

View File

@@ -285,63 +285,16 @@ class MediaRepository:
raise NotFoundError("Media ID has expired")
@trace
async def update_content(
self,
media_id: str,
media_type: str,
upload_name: Optional[str],
content: IO,
content_length: int,
auth_user: UserID,
) -> None:
"""Update the content of the given media ID.
Args:
media_id: The media ID to replace.
media_type: The content type of the file.
upload_name: The name of the file, if provided.
content: A file like object that is the content to store
content_length: The length of the content
auth_user: The user_id of the uploader
"""
file_info = FileInfo(server_name=None, file_id=media_id)
sha256reader = SHA256TransparentIOReader(content)
# This implements all of IO as it has a passthrough
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
sha256 = sha256reader.hexdigest()
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
logger.info("Stored local media in file %r", fname)
if should_quarantine:
logger.warn(
"Media has been automatically quarantined as it matched existing quarantined media"
)
await self.store.update_local_media(
media_id=media_id,
media_type=media_type,
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
quarantined_by="system" if should_quarantine else None,
)
try:
await self._generate_thumbnails(None, media_id, media_id, media_type)
except Exception as e:
logger.info("Failed to generate thumbnails: %s", e)
@trace
async def create_content(
async def create_or_update_content(
self,
media_type: str,
upload_name: Optional[str],
content: IO,
content_length: int,
auth_user: UserID,
media_id: Optional[str] = None,
) -> MXCUri:
"""Store uploaded content for a local user and return the mxc URL
"""Create or update the content of the given media ID.
Args:
media_type: The content type of the file.
@@ -349,16 +302,20 @@ class MediaRepository:
content: A file like object that is the content to store
content_length: The length of the content
auth_user: The user_id of the uploader
media_id: The media ID to update if provided, otherwise creates
new media ID.
Returns:
The mxc url of the stored content
"""
media_id = random_string(24)
is_new_media = media_id is None
if media_id is None:
media_id = random_string(24)
file_info = FileInfo(server_name=None, file_id=media_id)
# This implements all of IO as it has a passthrough
sha256reader = SHA256TransparentIOReader(content)
# This implements all of IO as it has a passthrough
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
sha256 = sha256reader.hexdigest()
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
@@ -370,16 +327,27 @@ class MediaRepository:
"Media has been automatically quarantined as it matched existing quarantined media"
)
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
quarantined_by="system" if should_quarantine else None,
)
if is_new_media:
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
quarantined_by="system" if should_quarantine else None,
)
else:
await self.store.update_local_media(
media_id=media_id,
media_type=media_type,
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
quarantined_by="system" if should_quarantine else None,
)
try:
await self._generate_thumbnails(None, media_id, media_id, media_type)

View File

@@ -120,7 +120,7 @@ class UploadServlet(BaseUploadServlet):
try:
content: IO = request.content # type: ignore
content_uri = await self.media_repo.create_content(
content_uri = await self.media_repo.create_or_update_content(
media_type, upload_name, content, content_length, requester.user
)
except SpamMediaException:
@@ -170,13 +170,13 @@ class AsyncUploadServlet(BaseUploadServlet):
try:
content: IO = request.content # type: ignore
await self.media_repo.update_content(
media_id,
await self.media_repo.create_or_update_content(
media_type,
upload_name,
content,
content_length,
requester.user,
media_id=media_id,
)
except SpamMediaException:
# For uploading of media we want to respond with a 400, instead of

View File

@@ -67,7 +67,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
def test_file_download(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
self.media_repo.create_content(
self.media_repo.create_or_update_content(
"text/plain",
"test_upload",
content,
@@ -110,7 +110,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
content = io.BytesIO(SMALL_PNG)
content_uri = self.get_success(
self.media_repo.create_content(
self.media_repo.create_or_update_content(
"image/png",
"test_png_upload",
content,
@@ -152,7 +152,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
self.media_repo.create_content(
self.media_repo.create_or_update_content(
"text/plain",
"test_upload",
content,
@@ -215,7 +215,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
def test_thumbnail_download_scaled(self) -> None:
content = io.BytesIO(small_png.data)
content_uri = self.get_success(
self.media_repo.create_content(
self.media_repo.create_or_update_content(
"image/png",
"test_png_thumbnail",
content,
@@ -255,7 +255,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
def test_thumbnail_download_cropped(self) -> None:
content = io.BytesIO(small_png.data)
content_uri = self.get_success(
self.media_repo.create_content(
self.media_repo.create_or_update_content(
"image/png",
"test_png_thumbnail",
content,

View File

@@ -78,7 +78,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# If the meda
random_content = bytes(random_string(24), "utf-8")
mxc_uri: MXCUri = self.get_success(
media_repository.create_content(
media_repository.create_or_update_content(
media_type="text/plain",
upload_name=None,
content=io.BytesIO(random_content),

View File

@@ -1952,7 +1952,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
def test_file_download(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
self.repo.create_content(
self.repo.create_or_update_content(
"text/plain",
"test_upload",
content,