save metadata to S3
This commit is contained in:
parent
effa26c73b
commit
0f12c37931
|
@ -1,10 +1,12 @@
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from json import dumps
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from boto3 import Session
|
from boto3 import Session
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..output import make_output_names
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
|
@ -17,13 +19,12 @@ logger = getLogger(__name__)
|
||||||
class PersistS3Stage(BaseStage):
|
class PersistS3Stage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
output: List[str],
|
|
||||||
bucket: str,
|
bucket: str,
|
||||||
endpoint_url: Optional[str] = None,
|
endpoint_url: Optional[str] = None,
|
||||||
profile_name: Optional[str] = None,
|
profile_name: Optional[str] = None,
|
||||||
|
@ -33,8 +34,8 @@ class PersistS3Stage(BaseStage):
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||||
|
|
||||||
# TODO: save metadata as well
|
image_names = make_output_names(server, worker.job, len(sources))
|
||||||
for source, name in zip(sources.as_images(), output):
|
for source, name in zip(sources.as_images(), image_names):
|
||||||
data = BytesIO()
|
data = BytesIO()
|
||||||
source.save(data, format=server.image_format)
|
source.save(data, format=server.image_format)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
|
@ -45,4 +46,18 @@ class PersistS3Stage(BaseStage):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error saving image to S3")
|
logger.exception("error saving image to S3")
|
||||||
|
|
||||||
|
metadata_names = make_output_names(
|
||||||
|
server, worker.job, len(sources), extension="json"
|
||||||
|
)
|
||||||
|
for metadata, name in zip(sources.metadata, metadata_names):
|
||||||
|
data = BytesIO()
|
||||||
|
data.write(dumps(metadata.tojson(server, [name])))
|
||||||
|
data.seek(0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
s3.upload_fileobj(data, bucket, name)
|
||||||
|
logger.info("saved metadata to s3://%s/%s", bucket, name)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("error saving metadata to S3")
|
||||||
|
|
||||||
return sources
|
return sources
|
||||||
|
|
|
@ -21,9 +21,11 @@ def make_output_names(
|
||||||
job_name: str,
|
job_name: str,
|
||||||
count: int = 1,
|
count: int = 1,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
|
extension: Optional[str] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
return [
|
return [
|
||||||
f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
|
f"{job_name}_{i}.{extension or server.image_format}"
|
||||||
|
for i in range(offset, count + offset)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -446,7 +446,7 @@ class BuildCachePathsTests(unittest.TestCase):
|
||||||
result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH)
|
result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH)
|
||||||
|
|
||||||
expected_paths = [
|
expected_paths = [
|
||||||
path.join("/path/to/cache", ONNX_MODEL),
|
path.join(CACHE_PATH, ONNX_MODEL),
|
||||||
path.join("/path/to/cache/client1", ONNX_MODEL),
|
path.join("/path/to/cache/client1", ONNX_MODEL),
|
||||||
]
|
]
|
||||||
self.assertEqual(result, expected_paths)
|
self.assertEqual(result, expected_paths)
|
||||||
|
@ -460,7 +460,7 @@ class BuildCachePathsTests(unittest.TestCase):
|
||||||
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
|
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
|
||||||
|
|
||||||
expected_paths = [
|
expected_paths = [
|
||||||
path.join("/path/to/cache", ONNX_MODEL),
|
path.join(CACHE_PATH, ONNX_MODEL),
|
||||||
path.join("/path/to/cache/client2", ONNX_MODEL),
|
path.join("/path/to/cache/client2", ONNX_MODEL),
|
||||||
]
|
]
|
||||||
self.assertEqual(result, expected_paths)
|
self.assertEqual(result, expected_paths)
|
||||||
|
@ -475,7 +475,7 @@ class BuildCachePathsTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_paths = [
|
expected_paths = [
|
||||||
path.join("/path/to/cache", TORCH_MODEL),
|
path.join(CACHE_PATH, TORCH_MODEL),
|
||||||
path.join("/path/to/cache/client3", TORCH_MODEL),
|
path.join("/path/to/cache/client3", TORCH_MODEL),
|
||||||
]
|
]
|
||||||
self.assertEqual(result, expected_paths)
|
self.assertEqual(result, expected_paths)
|
||||||
|
@ -489,7 +489,7 @@ class BuildCachePathsTests(unittest.TestCase):
|
||||||
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
|
result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format)
|
||||||
|
|
||||||
expected_paths = [
|
expected_paths = [
|
||||||
path.join("/path/to/cache", ONNX_MODEL),
|
path.join(CACHE_PATH, ONNX_MODEL),
|
||||||
path.join("/path/to/cache/client4", ONNX_MODEL),
|
path.join("/path/to/cache/client4", ONNX_MODEL),
|
||||||
]
|
]
|
||||||
self.assertEqual(result, expected_paths)
|
self.assertEqual(result, expected_paths)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/* eslint-disable max-lines */
|
/* eslint-disable max-lines */
|
||||||
import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
|
import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
|
||||||
import { create as batcher, keyResolver, windowScheduler, windowedFiniteBatchScheduler } from '@yornaath/batshit';
|
import { create as batcher, keyResolver, windowedFiniteBatchScheduler } from '@yornaath/batshit';
|
||||||
|
|
||||||
import { ServerParams } from '../config.js';
|
import { ServerParams } from '../config.js';
|
||||||
import { FIXED_FLOAT, FIXED_INTEGER, STATUS_SUCCESS } from '../constants.js';
|
import { FIXED_FLOAT, FIXED_INTEGER, STATUS_SUCCESS } from '../constants.js';
|
||||||
|
|
Loading…
Reference in New Issue