From 042181b9c522dadd75310ba888c5e748ecd778b0 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Nov 2023 21:45:32 -0600 Subject: [PATCH] fix(api): correctly handle image stacks in persist stages --- api/onnx_web/chain/persist_disk.py | 2 +- api/onnx_web/chain/persist_s3.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 7a2007ce..28a08848 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -31,7 +31,7 @@ class PersistDiskStage(BaseStage): ) -> StageResult: logger.info("persisting %s images to disk: %s", len(sources), output) - for source, name in zip(sources, output): + for source, name in zip(sources.as_image(), output): dest = save_image(server, name, source, params=params, size=size) logger.info("saved image to %s", dest) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 6bd03f72..060afc4f 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -1,6 +1,6 @@ from io import BytesIO from logging import getLogger -from typing import Optional +from typing import List, Optional from boto3 import Session from PIL import Image @@ -23,7 +23,7 @@ class PersistS3Stage(BaseStage): _params: ImageParams, sources: StageResult, *, - output: str, + output: List[str], bucket: str, endpoint_url: Optional[str] = None, profile_name: Optional[str] = None, @@ -33,14 +33,14 @@ class PersistS3Stage(BaseStage): session = Session(profile_name=profile_name) s3 = session.client("s3", endpoint_url=endpoint_url) - for source in sources.as_image(): + for source, name in zip(sources.as_image(), output): data = BytesIO() source.save(data, format=server.image_format) data.seek(0) try: - s3.upload_fileobj(data, bucket, output) - logger.info("saved image to s3://%s/%s", bucket, output) + s3.upload_fileobj(data, bucket, name) + logger.info("saved image to s3://%s/%s", bucket, name) except Exception: logger.exception("error saving image to S3")