diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index fca46134..9173d168 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -16,18 +16,12 @@ from .utils import base_join, hash_value logger = getLogger(__name__) -def make_output_name( +def make_output_names( server: ServerContext, - mode: str, - params: ImageParams, - size: Size, - extras: Optional[List[Optional[Param]]] = None, - count: Optional[int] = None, + job_name: str, + count: int = 1, offset: int = 0, ) -> List[str]: - count = count or params.batch - job_name = make_job_name(mode, params, size, extras) - return [ f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset) ] @@ -68,12 +62,14 @@ def save_result( result: StageResult, base_name: str, ) -> List[str]: + images = result.as_image() + outputs = make_output_names(server, base_name, len(images)) results = [] - for i, (image, metadata) in enumerate(zip(result.as_image(), result.metadata)): + for image, metadata, filename in zip(images, result.metadata, outputs): results.append( save_image( server, - base_name + f"_{i}.{server.image_format}", + filename, image, metadata, ) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 9eb86c4b..58c904aa 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -18,7 +18,7 @@ from ..diffusers.run import ( run_upscale_pipeline, ) from ..diffusers.utils import replace_wildcards -from ..output import make_job_name +from ..output import make_job_name, make_output_names from ..params import Progress, Size, StageParams, TileOrder from ..transformers.run import run_txt2txt_pipeline from ..utils import ( @@ -668,7 +668,10 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): # TODO: accumulate results if progress is not None: - # TODO: add output paths based on progress.results counter + outputs = None + if progress.results > 0: + outputs = make_output_names(server, job_name, progress.results) + return image_reply( job_name, status, @@ -676,6 +679,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): stages=Progress(progress.stages, 0), steps=Progress(progress.steps, 0), tiles=Progress(progress.tiles, 0), + outputs=outputs, ) return image_reply(job_name, status, "TODO")