feat(api): server setting to select image format
This commit is contained in:
parent
83992d9193
commit
6ce13096ed
|
@ -17,7 +17,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def persist_s3(
|
def persist_s3(
|
||||||
_ctx: ServerContext,
|
ctx: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
|
@ -32,7 +32,7 @@ def persist_s3(
|
||||||
s3 = session.client('s3', endpoint_url=endpoint_url)
|
s3 = session.client('s3', endpoint_url=endpoint_url)
|
||||||
|
|
||||||
data = BytesIO()
|
data = BytesIO()
|
||||||
source_image.save(data, format='png')
|
source_image.save(data, format=ctx.image_format)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -43,8 +43,8 @@ def json_params(
|
||||||
|
|
||||||
|
|
||||||
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
|
def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
|
||||||
path = base_join(ctx.output_path, output)
|
path = base_join(ctx.output_path, '%s.%s' % (output, ctx.image_format))
|
||||||
image.save(path)
|
image.save(path, format=ctx.image_format)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ class ServerContext:
|
||||||
num_workers: int = 1,
|
num_workers: int = 1,
|
||||||
block_platforms: List[str] = [],
|
block_platforms: List[str] = [],
|
||||||
default_platform: str = None,
|
default_platform: str = None,
|
||||||
|
image_format: str = 'png',
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -38,6 +39,7 @@ class ServerContext:
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
self.block_platforms = block_platforms
|
self.block_platforms = block_platforms
|
||||||
self.default_platform = default_platform
|
self.default_platform = default_platform
|
||||||
|
self.image_format = image_format
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environ(cls):
|
def from_environ(cls):
|
||||||
|
@ -56,6 +58,9 @@ class ServerContext:
|
||||||
'ONNX_WEB_BLOCK_PLATFORMS', '').split(','),
|
'ONNX_WEB_BLOCK_PLATFORMS', '').split(','),
|
||||||
default_platform=environ.get(
|
default_platform=environ.get(
|
||||||
'ONNX_WEB_DEFAULT_PLATFORM', None),
|
'ONNX_WEB_DEFAULT_PLATFORM', None),
|
||||||
|
image_format=environ.get(
|
||||||
|
'ONNX_WEB_IMAGE_FORMAT', 'png'
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -155,7 +160,7 @@ def make_output_name(
|
||||||
for param in extras:
|
for param in extras:
|
||||||
hash_value(sha, param)
|
hash_value(sha, param)
|
||||||
|
|
||||||
return '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)
|
return '%s_%s_%s_%s' % (mode, params.seed, sha.hexdigest(), now)
|
||||||
|
|
||||||
|
|
||||||
def base_join(base: str, tail: str) -> str:
|
def base_join(base: str, tail: str) -> str:
|
||||||
|
|
Loading…
Reference in New Issue