From 0f12c379310aa66e9b17c1e87bf7165d573be871 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 14 Jan 2024 20:03:16 -0600 Subject: [PATCH] save metadata to S3 --- api/onnx_web/chain/persist_s3.py | 25 ++++++++++++++++++++----- api/onnx_web/output.py | 4 +++- api/tests/convert/test_utils.py | 8 ++++---- gui/src/client/api.ts | 2 +- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 6c4cc5f3..4e118bb1 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -1,10 +1,12 @@ from io import BytesIO +from json import dumps from logging import getLogger -from typing import List, Optional +from typing import Optional from boto3 import Session from PIL import Image +from ..output import make_output_names from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext @@ -17,13 +19,12 @@ logger = getLogger(__name__) class PersistS3Stage(BaseStage): def run( self, - _worker: WorkerContext, + worker: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, sources: StageResult, *, - output: List[str], bucket: str, endpoint_url: Optional[str] = None, profile_name: Optional[str] = None, @@ -33,8 +34,8 @@ class PersistS3Stage(BaseStage): session = Session(profile_name=profile_name) s3 = session.client("s3", endpoint_url=endpoint_url) - # TODO: save metadata as well - for source, name in zip(sources.as_images(), output): + image_names = make_output_names(server, worker.job, len(sources)) + for source, name in zip(sources.as_images(), image_names): data = BytesIO() source.save(data, format=server.image_format) data.seek(0) @@ -45,4 +46,18 @@ class PersistS3Stage(BaseStage): except Exception: logger.exception("error saving image to S3") + metadata_names = make_output_names( + server, worker.job, len(sources), extension="json" + ) + for metadata, name in zip(sources.metadata, metadata_names): + data = BytesIO() + data.write(dumps(metadata.tojson(server, [name]))) + data.seek(0) + + try: + s3.upload_fileobj(data, bucket, name) + logger.info("saved metadata to s3://%s/%s", bucket, name) + except Exception: + logger.exception("error saving metadata to S3") + return sources diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 43d29787..8146d546 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -21,9 +21,11 @@ def make_output_names( job_name: str, count: int = 1, offset: int = 0, + extension: Optional[str] = None, ) -> List[str]: return [ - f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset) + f"{job_name}_{i}.{extension or server.image_format}" + for i in range(offset, count + offset) ] diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index b48e8672..1a314f25 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -446,7 +446,7 @@ class BuildCachePathsTests(unittest.TestCase): result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH) expected_paths = [ - path.join("/path/to/cache", ONNX_MODEL), + path.join(CACHE_PATH, ONNX_MODEL), path.join("/path/to/cache/client1", ONNX_MODEL), ] self.assertEqual(result, expected_paths) @@ -460,7 +460,7 @@ class BuildCachePathsTests(unittest.TestCase): result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format) expected_paths = [ - path.join("/path/to/cache", ONNX_MODEL), + path.join(CACHE_PATH, ONNX_MODEL), path.join("/path/to/cache/client2", ONNX_MODEL), ] self.assertEqual(result, expected_paths) @@ -475,7 +475,7 @@ class BuildCachePathsTests(unittest.TestCase): ) expected_paths = [ - path.join("/path/to/cache", TORCH_MODEL), + path.join(CACHE_PATH, TORCH_MODEL), path.join("/path/to/cache/client3", TORCH_MODEL), ] self.assertEqual(result, expected_paths) @@ -489,7 +489,7 @@ class BuildCachePathsTests(unittest.TestCase): result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format) expected_paths = [ - path.join("/path/to/cache", ONNX_MODEL), + path.join(CACHE_PATH, ONNX_MODEL), path.join("/path/to/cache/client4", ONNX_MODEL), ] self.assertEqual(result, expected_paths) diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 1e00ea43..c1297df0 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -1,6 +1,6 @@ /* eslint-disable max-lines */ import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils'; -import { create as batcher, keyResolver, windowScheduler, windowedFiniteBatchScheduler } from '@yornaath/batshit'; +import { create as batcher, keyResolver, windowedFiniteBatchScheduler } from '@yornaath/batshit'; import { ServerParams } from '../config.js'; import { FIXED_FLOAT, FIXED_INTEGER, STATUS_SUCCESS } from '../constants.js';