From 6ce13096edba7349b46d9e1727058e2694fe1035 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 2 Feb 2023 08:19:57 -0600 Subject: [PATCH] feat(api): server setting to select image format --- api/onnx_web/chain/persist_s3.py | 4 ++-- api/onnx_web/output.py | 4 ++-- api/onnx_web/utils.py | 7 ++++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 2653a7f1..d54f4a93 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -17,7 +17,7 @@ logger = getLogger(__name__) def persist_s3( - _ctx: ServerContext, + ctx: ServerContext, _stage: StageParams, _params: ImageParams, source_image: Image.Image, @@ -32,7 +32,7 @@ def persist_s3( s3 = session.client('s3', endpoint_url=endpoint_url) data = BytesIO() - source_image.save(data, format='png') + source_image.save(data, format=ctx.image_format) data.seek(0) try: diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 4b3291af..11b43f87 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -43,8 +43,8 @@ def json_params( def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str: - path = base_join(ctx.output_path, output) - image.save(path) + path = base_join(ctx.output_path, '%s.%s' % (output, ctx.image_format)) + image.save(path, format=ctx.image_format) return path diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 2090b085..cf5b4953 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -29,6 +29,7 @@ class ServerContext: num_workers: int = 1, block_platforms: List[str] = [], default_platform: str = None, + image_format: str = 'png', ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -38,6 +39,7 @@ class ServerContext: self.num_workers = num_workers self.block_platforms = block_platforms self.default_platform = default_platform + self.image_format = image_format @classmethod def from_environ(cls): @@ -56,6 +58,9 @@ class ServerContext: 'ONNX_WEB_BLOCK_PLATFORMS', '').split(','), default_platform=environ.get( 'ONNX_WEB_DEFAULT_PLATFORM', None), + image_format=environ.get( + 'ONNX_WEB_IMAGE_FORMAT', 'png' + ), ) @@ -155,7 +160,7 @@ def make_output_name( for param in extras: hash_value(sha, param) - return '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now) + return '%s_%s_%s_%s' % (mode, params.seed, sha.hexdigest(), now) def base_join(base: str, tail: str) -> str: