
save metadata to S3

This commit is contained in:
Sean Sube 2024-01-14 20:03:16 -06:00
parent effa26c73b
commit 0f12c37931
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 28 additions and 11 deletions

View File

@@ -1,10 +1,12 @@
 from io import BytesIO
+from json import dumps
 from logging import getLogger
-from typing import List, Optional
+from typing import Optional

 from boto3 import Session
 from PIL import Image

+from ..output import make_output_names
 from ..params import ImageParams, StageParams
 from ..server import ServerContext
 from ..worker import WorkerContext
@@ -17,13 +19,12 @@ logger = getLogger(__name__)
 class PersistS3Stage(BaseStage):
     def run(
         self,
-        _worker: WorkerContext,
+        worker: WorkerContext,
         server: ServerContext,
         _stage: StageParams,
         _params: ImageParams,
         sources: StageResult,
         *,
-        output: List[str],
         bucket: str,
         endpoint_url: Optional[str] = None,
         profile_name: Optional[str] = None,
@@ -33,8 +34,8 @@ class PersistS3Stage(BaseStage):
         session = Session(profile_name=profile_name)
         s3 = session.client("s3", endpoint_url=endpoint_url)

-        # TODO: save metadata as well
-        for source, name in zip(sources.as_images(), output):
+        image_names = make_output_names(server, worker.job, len(sources))
+        for source, name in zip(sources.as_images(), image_names):
             data = BytesIO()
             source.save(data, format=server.image_format)
             data.seek(0)
@@ -45,4 +46,18 @@ class PersistS3Stage(BaseStage):
             except Exception:
                 logger.exception("error saving image to S3")

+        metadata_names = make_output_names(
+            server, worker.job, len(sources), extension="json"
+        )
+        for metadata, name in zip(sources.metadata, metadata_names):
+            data = BytesIO()
+            data.write(dumps(metadata.tojson(server, [name])).encode("utf-8"))
+            data.seek(0)
+
+            try:
+                s3.upload_fileobj(data, bucket, name)
+                logger.info("saved metadata to s3://%s/%s", bucket, name)
+            except Exception:
+                logger.exception("error saving metadata to S3")
+
         return sources
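
In isolation, the new metadata path amounts to the minimal sketch below. It is standalone, so the bucket name, object key, and metadata dict are hypothetical stand-ins for the values the stage derives from sources.metadata and make_output_names; only the boto3 calls mirror the stage itself.

    from io import BytesIO
    from json import dumps

    from boto3 import Session

    # Hypothetical stand-ins for values the stage computes from its inputs.
    metadata = {"prompt": "an example prompt", "seed": 42}
    name = "example-job_0.json"
    bucket = "my-output-bucket"  # assumed bucket name

    session = Session(profile_name=None)  # default credentials chain
    s3 = session.client("s3", endpoint_url=None)

    # Serialize the metadata to JSON bytes and upload it like any other object.
    data = BytesIO(dumps(metadata).encode("utf-8"))
    s3.upload_fileobj(data, bucket, name)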

View File

@@ -21,9 +21,11 @@ def make_output_names(
     job_name: str,
     count: int = 1,
     offset: int = 0,
+    extension: Optional[str] = None,
 ) -> List[str]:
     return [
-        f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
+        f"{job_name}_{i}.{extension or server.image_format}"
+        for i in range(offset, count + offset)
     ]
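
The effect of the new extension keyword, sketched with a stub in place of the real ServerContext (the stub and its image_format value are assumptions for illustration):

    from typing import List, Optional


    class StubServer:
        """Stand-in for ServerContext; only the attribute used here."""

        image_format = "png"  # assumed default format


    def make_output_names(
        server: StubServer,
        job_name: str,
        count: int = 1,
        offset: int = 0,
        extension: Optional[str] = None,
    ) -> List[str]:
        # Fall back to the server's image format when no extension is given.
        return [
            f"{job_name}_{i}.{extension or server.image_format}"
            for i in range(offset, count + offset)
        ]


    server = StubServer()
    print(make_output_names(server, "job", count=2))
    # ['job_0.png', 'job_1.png']
    print(make_output_names(server, "job", count=2, extension="json"))
    # ['job_0.json', 'job_1.json']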

View File

@@ -446,7 +446,7 @@ class BuildCachePathsTests(unittest.TestCase):
         result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH)

         expected_paths = [
-            path.join("/path/to/cache", ONNX_MODEL),
+            path.join(CACHE_PATH, ONNX_MODEL),
             path.join("/path/to/cache/client1", ONNX_MODEL),
         ]
         self.assertEqual(result, expected_paths)
@@ -460,7 +460,7 @@ class BuildCachePathsTests(unittest.TestCase):
         result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)

         expected_paths = [
-            path.join("/path/to/cache", ONNX_MODEL),
+            path.join(CACHE_PATH, ONNX_MODEL),
             path.join("/path/to/cache/client2", ONNX_MODEL),
         ]
         self.assertEqual(result, expected_paths)
@@ -475,7 +475,7 @@ class BuildCachePathsTests(unittest.TestCase):
         )

         expected_paths = [
-            path.join("/path/to/cache", TORCH_MODEL),
+            path.join(CACHE_PATH, TORCH_MODEL),
             path.join("/path/to/cache/client3", TORCH_MODEL),
         ]
         self.assertEqual(result, expected_paths)
@@ -489,7 +489,7 @@ class BuildCachePathsTests(unittest.TestCase):
         result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)

         expected_paths = [
-            path.join("/path/to/cache", ONNX_MODEL),
+            path.join(CACHE_PATH, ONNX_MODEL),
             path.join("/path/to/cache/client4", ONNX_MODEL),
         ]
         self.assertEqual(result, expected_paths)
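
These hunks only swap a repeated string literal for the existing CACHE_PATH constant. Judging from the literal it replaces, the test module presumably defines constants along these lines (the cache path is pinned down by the old literal; the model filenames are hypothetical):

    CACHE_PATH = "/path/to/cache"  # implied by the literal it replaces
    ONNX_MODEL = "model.onnx"      # hypothetical example filename
    TORCH_MODEL = "model.pth"      # hypothetical example filename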

View File

@@ -1,6 +1,6 @@
 /* eslint-disable max-lines */
 import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
-import { create as batcher, keyResolver, windowScheduler, windowedFiniteBatchScheduler } from '@yornaath/batshit';
+import { create as batcher, keyResolver, windowedFiniteBatchScheduler } from '@yornaath/batshit';

 import { ServerParams } from '../config.js';
 import { FIXED_FLOAT, FIXED_INTEGER, STATUS_SUCCESS } from '../constants.js';