
save metadata to S3

Sean Sube 2024-01-14 20:03:16 -06:00
parent effa26c73b
commit 0f12c37931
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 28 additions and 11 deletions

View File

@@ -1,10 +1,12 @@
 from io import BytesIO
+from json import dumps
 from logging import getLogger
-from typing import List, Optional
+from typing import Optional
 
 from boto3 import Session
 from PIL import Image
 
+from ..output import make_output_names
 from ..params import ImageParams, StageParams
 from ..server import ServerContext
 from ..worker import WorkerContext
@@ -17,13 +19,12 @@ logger = getLogger(__name__)
 class PersistS3Stage(BaseStage):
     def run(
         self,
-        _worker: WorkerContext,
+        worker: WorkerContext,
         server: ServerContext,
         _stage: StageParams,
         _params: ImageParams,
         sources: StageResult,
         *,
-        output: List[str],
         bucket: str,
         endpoint_url: Optional[str] = None,
         profile_name: Optional[str] = None,
@@ -33,8 +34,8 @@ class PersistS3Stage(BaseStage):
         session = Session(profile_name=profile_name)
         s3 = session.client("s3", endpoint_url=endpoint_url)
 
-        # TODO: save metadata as well
-        for source, name in zip(sources.as_images(), output):
+        image_names = make_output_names(server, worker.job, len(sources))
+        for source, name in zip(sources.as_images(), image_names):
             data = BytesIO()
             source.save(data, format=server.image_format)
             data.seek(0)
@@ -45,4 +46,18 @@
             except Exception:
                 logger.exception("error saving image to S3")
 
+        metadata_names = make_output_names(
+            server, worker.job, len(sources), extension="json"
+        )
+        for metadata, name in zip(sources.metadata, metadata_names):
+            data = BytesIO()
+            data.write(dumps(metadata.tojson(server, [name])))
+            data.seek(0)
+
+            try:
+                s3.upload_fileobj(data, bucket, name)
+                logger.info("saved metadata to s3://%s/%s", bucket, name)
+            except Exception:
+                logger.exception("error saving metadata to S3")
+
         return sources
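
A minimal standalone sketch of the upload pattern used above, with hypothetical bucket, key, and metadata values (note that BytesIO holds bytes, so the JSON string is encoded before writing):

from io import BytesIO
from json import dumps

from boto3 import Session

# Hypothetical values for illustration; the stage receives these as arguments.
bucket = "example-bucket"
name = "job_0.json"
metadata = {"prompt": "an example prompt", "seed": 42}

session = Session()  # optionally Session(profile_name=...)
s3 = session.client("s3")  # optionally endpoint_url=... for S3-compatible stores

# Serialize the metadata and upload it as a file-like object.
# BytesIO expects bytes, so the JSON string is encoded here.
data = BytesIO(dumps(metadata).encode("utf-8"))
s3.upload_fileobj(data, bucket, name)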

View File

@@ -21,9 +21,11 @@ def make_output_names(
     job_name: str,
     count: int = 1,
     offset: int = 0,
+    extension: Optional[str] = None,
 ) -> List[str]:
     return [
-        f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
+        f"{job_name}_{i}.{extension or server.image_format}"
+        for i in range(offset, count + offset)
     ]
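
A quick sketch of how the new extension parameter behaves, as a standalone reimplementation for illustration (a plain image_format string stands in for server.image_format):

from typing import List, Optional

def make_output_names(image_format: str, job_name: str, count: int = 1,
                      offset: int = 0, extension: Optional[str] = None) -> List[str]:
    # Falls back to the image format when no explicit extension is given.
    return [
        f"{job_name}_{i}.{extension or image_format}"
        for i in range(offset, count + offset)
    ]

print(make_output_names("png", "job"))                       # ['job_0.png']
print(make_output_names("png", "job", 2, extension="json"))  # ['job_0.json', 'job_1.json']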

View File

@@ -446,7 +446,7 @@ class BuildCachePathsTests(unittest.TestCase):
         result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH)
         expected_paths = [
             path.join("/path/to/cache", ONNX_MODEL),
-            path.join(CACHE_PATH, ONNX_MODEL),
+            path.join("/path/to/cache/client1", ONNX_MODEL),
         ]
         self.assertEqual(result, expected_paths)
@@ -460,7 +460,7 @@ class BuildCachePathsTests(unittest.TestCase):
         result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
         expected_paths = [
             path.join("/path/to/cache", ONNX_MODEL),
-            path.join(CACHE_PATH, ONNX_MODEL),
+            path.join("/path/to/cache/client2", ONNX_MODEL),
         ]
         self.assertEqual(result, expected_paths)
@@ -475,7 +475,7 @@ class BuildCachePathsTests(unittest.TestCase):
         )
         expected_paths = [
             path.join("/path/to/cache", TORCH_MODEL),
-            path.join(CACHE_PATH, TORCH_MODEL),
+            path.join("/path/to/cache/client3", TORCH_MODEL),
         ]
         self.assertEqual(result, expected_paths)
@@ -489,7 +489,7 @@ class BuildCachePathsTests(unittest.TestCase):
         result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
         expected_paths = [
             path.join("/path/to/cache", ONNX_MODEL),
-            path.join(CACHE_PATH, ONNX_MODEL),
+            path.join("/path/to/cache/client4", ONNX_MODEL),
         ]
         self.assertEqual(result, expected_paths)
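
The updated expectations all follow the same pattern: the second candidate path is now scoped to a per-client subdirectory under the cache root. A simplified sketch of that lookup order, with a hypothetical signature (the real build_cache_paths also takes a conversion context and model format):

from os import path
from typing import List, Optional

def build_cache_paths(name: str, cache_path: str, client: Optional[str] = None) -> List[str]:
    # Candidate locations for a cached model: the cache root first,
    # then a per-client subdirectory.
    paths = [path.join(cache_path, name)]
    if client is not None:
        paths.append(path.join(cache_path, client, name))
    return paths

print(build_cache_paths("model.onnx", "/path/to/cache", "client1"))
# ['/path/to/cache/model.onnx', '/path/to/cache/client1/model.onnx']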

View File

@@ -1,6 +1,6 @@
 /* eslint-disable max-lines */
 import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
-import { create as batcher, keyResolver, windowScheduler, windowedFiniteBatchScheduler } from '@yornaath/batshit';
+import { create as batcher, keyResolver, windowedFiniteBatchScheduler } from '@yornaath/batshit';
 
 import { ServerParams } from '../config.js';
 import { FIXED_FLOAT, FIXED_INTEGER, STATUS_SUCCESS } from '../constants.js';