From feb9ea9e1a022a751327e497dedf5b2381503dc7 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Sep 2023 23:06:00 -0500 Subject: [PATCH] fix order, parse right data, use output names --- api/onnx_web/chain/blend_grid.py | 9 +++++---- api/onnx_web/chain/persist_disk.py | 11 +++++------ api/onnx_web/chain/source_txt2img.py | 2 +- api/onnx_web/server/api.py | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index 51472a42..34a45753 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -37,11 +37,12 @@ class BlendGridStage(BaseStage): output = Image.new("RGB", (size[0] * width, size[1] * height)) # TODO: labels - for i in order or range(len(sources)): + order = order or range(len(sources)) + for i in len(order): x = i % width - y = i / width + y = i // width - output.paste(sources[i], (x * size[0], y * size[1])) + n = order[i] + output.paste(sources[n], (x * size[0], y * size[1])) return [output] - diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 4ca200d1..fb9e53c0 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import List +from typing import List, Optional from PIL import Image @@ -21,13 +21,12 @@ class PersistDiskStage(BaseStage): params: ImageParams, sources: List[Image.Image], *, - outputs: List[str], - stage_source: Image.Image, + output: List[str], + stage_source: Optional[Image.Image] = None, **kwargs, ) -> List[Image.Image]: - for source, output in zip(sources, outputs): - # TODO: append index to output name - dest = save_image(server, output, source, params=params) + for source, name in zip(sources, output): + dest = save_image(server, name, source, params=params) logger.info("saved image to %s", dest) return sources diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index c69443a1..cda81b3d 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -32,7 +32,7 @@ class SourceTxt2ImgStage(BaseStage): params: ImageParams, sources: List[Image.Image], *, - dims: Tuple[int, int, int], + dims: Tuple[int, int, int] = None, size: Size, callback: Optional[ProgressCallback] = None, latents: Optional[np.ndarray] = None, diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 3d30d2e0..85103bba 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -387,7 +387,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): validate(data, schema) # get defaults from the regular parameters - device, _params, _size = pipeline_from_request(server) + device, _params, _size = pipeline_from_request(server, data=data) pipeline = ChainPipeline() for stage_data in data.get("stages", []): stage_class = CHAIN_STAGES[stage_data.get("type")]