save metadata to S3
This commit is contained in:
parent
effa26c73b
commit
0f12c37931
|
@ -1,10 +1,12 @@
|
|||
from io import BytesIO
|
||||
from json import dumps
|
||||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from boto3 import Session
|
||||
from PIL import Image
|
||||
|
||||
from ..output import make_output_names
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
|
@ -17,13 +19,12 @@ logger = getLogger(__name__)
|
|||
class PersistS3Stage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_worker: WorkerContext,
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
sources: StageResult,
|
||||
*,
|
||||
output: List[str],
|
||||
bucket: str,
|
||||
endpoint_url: Optional[str] = None,
|
||||
profile_name: Optional[str] = None,
|
||||
|
@ -33,8 +34,8 @@ class PersistS3Stage(BaseStage):
|
|||
session = Session(profile_name=profile_name)
|
||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||
|
||||
# TODO: save metadata as well
|
||||
for source, name in zip(sources.as_images(), output):
|
||||
image_names = make_output_names(server, worker.job, len(sources))
|
||||
for source, name in zip(sources.as_images(), image_names):
|
||||
data = BytesIO()
|
||||
source.save(data, format=server.image_format)
|
||||
data.seek(0)
|
||||
|
@ -45,4 +46,18 @@ class PersistS3Stage(BaseStage):
|
|||
except Exception:
|
||||
logger.exception("error saving image to S3")
|
||||
|
||||
metadata_names = make_output_names(
|
||||
server, worker.job, len(sources), extension="json"
|
||||
)
|
||||
for metadata, name in zip(sources.metadata, metadata_names):
|
||||
data = BytesIO()
|
||||
data.write(dumps(metadata.tojson(server, [name])))
|
||||
data.seek(0)
|
||||
|
||||
try:
|
||||
s3.upload_fileobj(data, bucket, name)
|
||||
logger.info("saved metadata to s3://%s/%s", bucket, name)
|
||||
except Exception:
|
||||
logger.exception("error saving metadata to S3")
|
||||
|
||||
return sources
|
||||
|
|
|
@ -21,9 +21,11 @@ def make_output_names(
|
|||
job_name: str,
|
||||
count: int = 1,
|
||||
offset: int = 0,
|
||||
extension: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
return [
|
||||
f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
|
||||
f"{job_name}_{i}.{extension or server.image_format}"
|
||||
for i in range(offset, count + offset)
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -446,7 +446,7 @@ class BuildCachePathsTests(unittest.TestCase):
|
|||
result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", ONNX_MODEL),
|
||||
path.join(CACHE_PATH, ONNX_MODEL),
|
||||
path.join("/path/to/cache/client1", ONNX_MODEL),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
@ -460,7 +460,7 @@ class BuildCachePathsTests(unittest.TestCase):
|
|||
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", ONNX_MODEL),
|
||||
path.join(CACHE_PATH, ONNX_MODEL),
|
||||
path.join("/path/to/cache/client2", ONNX_MODEL),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
@ -475,7 +475,7 @@ class BuildCachePathsTests(unittest.TestCase):
|
|||
)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", TORCH_MODEL),
|
||||
path.join(CACHE_PATH, TORCH_MODEL),
|
||||
path.join("/path/to/cache/client3", TORCH_MODEL),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
@ -489,7 +489,7 @@ class BuildCachePathsTests(unittest.TestCase):
|
|||
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", ONNX_MODEL),
|
||||
path.join(CACHE_PATH, ONNX_MODEL),
|
||||
path.join("/path/to/cache/client4", ONNX_MODEL),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
/* eslint-disable max-lines */
|
||||
import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
|
||||
import { create as batcher, keyResolver, windowScheduler, windowedFiniteBatchScheduler } from '@yornaath/batshit';
|
||||
import { create as batcher, keyResolver, windowedFiniteBatchScheduler } from '@yornaath/batshit';
|
||||
|
||||
import { ServerParams } from '../config.js';
|
||||
import { FIXED_FLOAT, FIXED_INTEGER, STATUS_SUCCESS } from '../constants.js';
|
||||
|
|
Loading…
Reference in New Issue