From 5f3b84827b3e1db0585f0ef5dd11ad8cbfcbc4b2 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 20 Feb 2023 08:35:18 -0600 Subject: [PATCH] feat(api): add batch size to txt2img and img2img pipelines (#195) --- api/onnx_web/chain/persist_disk.py | 5 +- api/onnx_web/diffusion/run.py | 78 +++++++++++---------- api/onnx_web/output.py | 17 ++--- api/onnx_web/params.py | 4 ++ api/params.json | 6 ++ gui/src/client.ts | 1 + gui/src/components/control/ImageControl.tsx | 15 ++++ gui/src/state.ts | 1 + 8 files changed, 77 insertions(+), 50 deletions(-) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 9a5f0cd0..041dabae 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -1,4 +1,5 @@ from logging import getLogger +from typing import List from PIL import Image @@ -16,12 +17,12 @@ def persist_disk( _params: ImageParams, source: Image.Image, *, - output: str, + output: List[str], stage_source: Image.Image, **kwargs, ) -> Image.Image: source = stage_source or source - dest = save_image(server, output, source) + dest = save_image(server, output[0], source) logger.info("saved image to %s", dest) return source diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 5393fa66..a5417e84 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -25,7 +25,7 @@ def run_txt2img_pipeline( server: ServerContext, params: ImageParams, size: Size, - output: str, + outputs: List[str], upscale: UpscaleParams, ) -> None: latents = get_latents_from_seed(params.seed, size) @@ -50,6 +50,7 @@ def run_txt2img_pipeline( guidance_scale=params.cfg, latents=latents, negative_prompt=params.negative_prompt, + num_images_per_prompt=params.batch, num_inference_steps=params.steps, eta=params.eta, callback=progress, @@ -64,27 +65,27 @@ def run_txt2img_pipeline( guidance_scale=params.cfg, latents=latents, negative_prompt=params.negative_prompt, + num_images_per_prompt=params.batch, num_inference_steps=params.steps, eta=params.eta, callback=progress, ) - image = result.images[0] - image = run_upscale_correction( - job, - server, - StageParams(), - params, - image, - upscale=upscale, - callback=progress, - ) + for image, output in zip(result.images, outputs): + image = run_upscale_correction( + job, + server, + StageParams(), + params, + image, + upscale=upscale, + callback=progress, + ) - dest = save_image(server, output, image) - save_params(server, output, params, size, upscale=upscale) + dest = save_image(server, output, image) + save_params(server, output, params, size, upscale=upscale) del pipe - del image del result run_gc([job.get_device()]) @@ -96,7 +97,7 @@ def run_img2img_pipeline( job: JobContext, server: ServerContext, params: ImageParams, - output: str, + outputs: List[str], upscale: UpscaleParams, source: Image.Image, strength: float, @@ -119,6 +120,7 @@ def run_img2img_pipeline( generator=rng, guidance_scale=params.cfg, negative_prompt=params.negative_prompt, + num_images_per_prompt=params.batch, num_inference_steps=params.steps, strength=strength, eta=params.eta, @@ -132,29 +134,29 @@ def run_img2img_pipeline( generator=rng, guidance_scale=params.cfg, negative_prompt=params.negative_prompt, + num_images_per_prompt=params.batch, num_inference_steps=params.steps, strength=strength, eta=params.eta, callback=progress, ) - image = result.images[0] - image = run_upscale_correction( - job, - server, - StageParams(), - params, - image, - upscale=upscale, - callback=progress, - ) + for image, output in zip(result.images, outputs): + image = run_upscale_correction( + job, + server, + StageParams(), + params, + image, + upscale=upscale, + callback=progress, + ) - dest = save_image(server, output, image) - size = Size(*source.size) - save_params(server, output, params, size, upscale=upscale) + dest = save_image(server, output, image) + size = Size(*source.size) + save_params(server, output, params, size, upscale=upscale) del pipe - del image del result run_gc([job.get_device()]) @@ -167,7 +169,7 @@ def run_inpaint_pipeline( server: ServerContext, params: ImageParams, size: Size, - output: str, + outputs: List[str], upscale: UpscaleParams, source: Image.Image, mask: Image.Image, @@ -202,8 +204,8 @@ def run_inpaint_pipeline( job, server, stage, params, image, upscale=upscale, callback=progress ) - dest = save_image(server, output, image) - save_params(server, output, params, size, upscale=upscale, border=border) + dest = save_image(server, outputs[0], image) + save_params(server, outputs[0], params, size, upscale=upscale, border=border) del image @@ -217,7 +219,7 @@ def run_upscale_pipeline( server: ServerContext, params: ImageParams, size: Size, - output: str, + outputs: List[str], upscale: UpscaleParams, source: Image.Image, ) -> None: @@ -228,8 +230,8 @@ def run_upscale_pipeline( job, server, stage, params, source, upscale=upscale, callback=progress ) - dest = save_image(server, output, image) - save_params(server, output, params, size, upscale=upscale) + dest = save_image(server, outputs[0], image) + save_params(server, outputs[0], params, size, upscale=upscale) del image @@ -243,7 +245,7 @@ def run_blend_pipeline( server: ServerContext, params: ImageParams, size: Size, - output: str, + outputs: List[str], upscale: UpscaleParams, sources: List[Image.Image], mask: Image.Image, @@ -266,8 +268,8 @@ def run_blend_pipeline( job, server, stage, params, image, upscale=upscale, callback=progress ) - dest = save_image(server, output, image) - save_params(server, output, params, size, upscale=upscale) + dest = save_image(server, outputs[0], image) + save_params(server, outputs[0], params, size, upscale=upscale) del image diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index abcf5c33..330ca14d 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -4,7 +4,7 @@ from logging import getLogger from os import path from struct import pack from time import time -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple from PIL import Image @@ -63,7 +63,7 @@ def make_output_name( params: ImageParams, size: Size, extras: Optional[Tuple[Param]] = None, -) -> str: +) -> List[str]: now = int(time()) sha = sha256() @@ -82,13 +82,10 @@ def make_output_name( for param in extras: hash_value(sha, param) - return "%s_%s_%s_%s.%s" % ( - mode, - params.seed, - sha.hexdigest(), - now, - ctx.image_format, - ) + return [ + f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{ctx.image_format}" + for i in range(params.batch) + ] def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str: @@ -106,7 +103,7 @@ def save_params( upscale: Optional[UpscaleParams] = None, border: Optional[Border] = None, ) -> str: - path = base_join(ctx.output_path, "%s.json" % (output)) + path = base_join(ctx.output_path, f"{output}.json") json = json_params(output, params, size, upscale=upscale, border=border) with open(path, "w") as f: f.write(dumps(json)) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 16528000..fada5d23 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -156,6 +156,7 @@ class ImageParams: negative_prompt: Optional[str] = None, lpw: bool = False, eta: float = 0.0, + batch: int = 1, ) -> None: self.model = model self.scheduler = scheduler @@ -166,6 +167,7 @@ class ImageParams: self.steps = steps self.lpw = lpw or False self.eta = eta + self.batch = batch def tojson(self) -> Dict[str, Optional[Param]]: return { @@ -178,6 +180,7 @@ class ImageParams: "steps": self.steps, "lpw": self.lpw, "eta": self.eta, + "batch": self.batch, } def with_args(self, **kwargs): @@ -191,6 +194,7 @@ class ImageParams: kwargs.get("negative_prompt", self.negative_prompt), kwargs.get("lpw", self.lpw), kwargs.get("eta", self.eta), + kwargs.get("batch", self.batch), ) diff --git a/api/params.json b/api/params.json index 7dda2d2c..599634a0 100644 --- a/api/params.json +++ b/api/params.json @@ -1,5 +1,11 @@ { "version": "0.7.1", + "batch": { + "default": 1, + "min": 1, + "max": 5, + "step": 1 + }, "bottom": { "default": 0, "min": 0, diff --git a/gui/src/client.ts b/gui/src/client.ts index 940dbcc3..6d094552 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -42,6 +42,7 @@ export interface BaseImgParams { prompt: string; negativePrompt?: string; + batch: number; cfg: number; steps: number; seed: number; diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index 6b80e552..fdeb26b8 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -68,6 +68,21 @@ export function ImageControl(props: ImageControlProps) { } }} /> + { + if (doesExist(props.onChange)) { + props.onChange({ + ...controlState, + batch, + }); + } + }} + /> { return { + batch: defaults.batch.default, cfg: defaults.cfg.default, eta: defaults.eta.default, negativePrompt: defaults.negativePrompt.default,