Merge create_content and update_content
This deduplicates a bunch of logic.
This commit is contained in:
@@ -824,7 +824,7 @@ class SsoHandler:
|
||||
return True
|
||||
|
||||
# store it in media repository
|
||||
avatar_mxc_url = await self._media_repo.create_content(
|
||||
avatar_mxc_url = await self._media_repo.create_or_update_content(
|
||||
media_type=headers[b"Content-Type"][0].decode("utf-8"),
|
||||
upload_name=upload_name,
|
||||
content=picture,
|
||||
|
||||
@@ -285,63 +285,16 @@ class MediaRepository:
|
||||
raise NotFoundError("Media ID has expired")
|
||||
|
||||
@trace
|
||||
async def update_content(
|
||||
self,
|
||||
media_id: str,
|
||||
media_type: str,
|
||||
upload_name: Optional[str],
|
||||
content: IO,
|
||||
content_length: int,
|
||||
auth_user: UserID,
|
||||
) -> None:
|
||||
"""Update the content of the given media ID.
|
||||
|
||||
Args:
|
||||
media_id: The media ID to replace.
|
||||
media_type: The content type of the file.
|
||||
upload_name: The name of the file, if provided.
|
||||
content: A file like object that is the content to store
|
||||
content_length: The length of the content
|
||||
auth_user: The user_id of the uploader
|
||||
"""
|
||||
file_info = FileInfo(server_name=None, file_id=media_id)
|
||||
sha256reader = SHA256TransparentIOReader(content)
|
||||
# This implements all of IO as it has a passthrough
|
||||
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
|
||||
sha256 = sha256reader.hexdigest()
|
||||
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
|
||||
logger.info("Stored local media in file %r", fname)
|
||||
|
||||
if should_quarantine:
|
||||
logger.warn(
|
||||
"Media has been automatically quarantined as it matched existing quarantined media"
|
||||
)
|
||||
|
||||
await self.store.update_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._generate_thumbnails(None, media_id, media_id, media_type)
|
||||
except Exception as e:
|
||||
logger.info("Failed to generate thumbnails: %s", e)
|
||||
|
||||
@trace
|
||||
async def create_content(
|
||||
async def create_or_update_content(
|
||||
self,
|
||||
media_type: str,
|
||||
upload_name: Optional[str],
|
||||
content: IO,
|
||||
content_length: int,
|
||||
auth_user: UserID,
|
||||
media_id: Optional[str] = None,
|
||||
) -> MXCUri:
|
||||
"""Store uploaded content for a local user and return the mxc URL
|
||||
"""Create or update the content of the given media ID.
|
||||
|
||||
Args:
|
||||
media_type: The content type of the file.
|
||||
@@ -349,16 +302,20 @@ class MediaRepository:
|
||||
content: A file like object that is the content to store
|
||||
content_length: The length of the content
|
||||
auth_user: The user_id of the uploader
|
||||
media_id: The media ID to update if provided, otherwise creates
|
||||
new media ID.
|
||||
|
||||
Returns:
|
||||
The mxc url of the stored content
|
||||
"""
|
||||
|
||||
media_id = random_string(24)
|
||||
is_new_media = media_id is None
|
||||
if media_id is None:
|
||||
media_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(server_name=None, file_id=media_id)
|
||||
# This implements all of IO as it has a passthrough
|
||||
sha256reader = SHA256TransparentIOReader(content)
|
||||
# This implements all of IO as it has a passthrough
|
||||
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
|
||||
sha256 = sha256reader.hexdigest()
|
||||
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
|
||||
@@ -370,16 +327,27 @@ class MediaRepository:
|
||||
"Media has been automatically quarantined as it matched existing quarantined media"
|
||||
)
|
||||
|
||||
await self.store.store_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
if is_new_media:
|
||||
await self.store.store_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
else:
|
||||
await self.store.update_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._generate_thumbnails(None, media_id, media_id, media_type)
|
||||
|
||||
@@ -120,7 +120,7 @@ class UploadServlet(BaseUploadServlet):
|
||||
|
||||
try:
|
||||
content: IO = request.content # type: ignore
|
||||
content_uri = await self.media_repo.create_content(
|
||||
content_uri = await self.media_repo.create_or_update_content(
|
||||
media_type, upload_name, content, content_length, requester.user
|
||||
)
|
||||
except SpamMediaException:
|
||||
@@ -170,13 +170,13 @@ class AsyncUploadServlet(BaseUploadServlet):
|
||||
|
||||
try:
|
||||
content: IO = request.content # type: ignore
|
||||
await self.media_repo.update_content(
|
||||
media_id,
|
||||
await self.media_repo.create_or_update_content(
|
||||
media_type,
|
||||
upload_name,
|
||||
content,
|
||||
content_length,
|
||||
requester.user,
|
||||
media_id=media_id,
|
||||
)
|
||||
except SpamMediaException:
|
||||
# For uploading of media we want to respond with a 400, instead of
|
||||
|
||||
@@ -67,7 +67,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
|
||||
def test_file_download(self) -> None:
|
||||
content = io.BytesIO(b"file_to_stream")
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"text/plain",
|
||||
"test_upload",
|
||||
content,
|
||||
@@ -110,7 +110,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
content = io.BytesIO(SMALL_PNG)
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"image/png",
|
||||
"test_png_upload",
|
||||
content,
|
||||
@@ -152,7 +152,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
content = io.BytesIO(b"file_to_stream")
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"text/plain",
|
||||
"test_upload",
|
||||
content,
|
||||
@@ -215,7 +215,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
|
||||
def test_thumbnail_download_scaled(self) -> None:
|
||||
content = io.BytesIO(small_png.data)
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"image/png",
|
||||
"test_png_thumbnail",
|
||||
content,
|
||||
@@ -255,7 +255,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
|
||||
def test_thumbnail_download_cropped(self) -> None:
|
||||
content = io.BytesIO(small_png.data)
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"image/png",
|
||||
"test_png_thumbnail",
|
||||
content,
|
||||
|
||||
@@ -78,7 +78,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
|
||||
# If the meda
|
||||
random_content = bytes(random_string(24), "utf-8")
|
||||
mxc_uri: MXCUri = self.get_success(
|
||||
media_repository.create_content(
|
||||
media_repository.create_or_update_content(
|
||||
media_type="text/plain",
|
||||
upload_name=None,
|
||||
content=io.BytesIO(random_content),
|
||||
|
||||
@@ -1952,7 +1952,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
|
||||
def test_file_download(self) -> None:
|
||||
content = io.BytesIO(b"file_to_stream")
|
||||
content_uri = self.get_success(
|
||||
self.repo.create_content(
|
||||
self.repo.create_or_update_content(
|
||||
"text/plain",
|
||||
"test_upload",
|
||||
content,
|
||||
|
||||
Reference in New Issue
Block a user