From c1189aad965766fe8d89f1f87046976b7ff1ee3f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Feb 2023 23:29:26 -0600 Subject: [PATCH] feat: add eta parameter (fixes #194) --- api/onnx_web/chain/blend_inpaint.py | 2 + .../chain/upscale_stable_diffusion.py | 1 + api/onnx_web/diffusion/run.py | 4 ++ api/onnx_web/params.py | 6 ++- api/onnx_web/serve.py | 8 +++ api/params.json | 6 +++ gui/src/client.ts | 2 + gui/src/components/control/ImageControl.tsx | 52 +++++++++++++------ gui/src/state.ts | 1 + 9 files changed, 64 insertions(+), 18 deletions(-) diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index a19cb8b7..7c3b32a7 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -91,6 +91,7 @@ def blend_inpaint( negative_prompt=params.negative_prompt, num_inference_steps=params.steps, width=size.width, + eta=params.eta, callback=callback, ) else: @@ -106,6 +107,7 @@ def blend_inpaint( negative_prompt=params.negative_prompt, num_inference_steps=params.steps, width=size.width, + eta=params.eta, callback=callback, ) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index c5356e11..00a24a08 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -88,5 +88,6 @@ def upscale_stable_diffusion( source, generator=generator, num_inference_steps=params.steps, + eta=params.eta, callback=callback, ).images[0] diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index ebd56cb6..5393fa66 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -51,6 +51,7 @@ def run_txt2img_pipeline( latents=latents, negative_prompt=params.negative_prompt, num_inference_steps=params.steps, + eta=params.eta, callback=progress, ) else: @@ -64,6 +65,7 @@ def run_txt2img_pipeline( latents=latents, negative_prompt=params.negative_prompt, num_inference_steps=params.steps, + eta=params.eta, callback=progress, ) @@ -119,6 +121,7 @@ def run_img2img_pipeline( negative_prompt=params.negative_prompt, num_inference_steps=params.steps, strength=strength, + eta=params.eta, callback=progress, ) else: @@ -131,6 +134,7 @@ def run_img2img_pipeline( negative_prompt=params.negative_prompt, num_inference_steps=params.steps, strength=strength, + eta=params.eta, callback=progress, ) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index c59f7be3..16528000 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -154,7 +154,8 @@ class ImageParams: steps: int, seed: int, negative_prompt: Optional[str] = None, - lpw: Optional[bool] = False, + lpw: bool = False, + eta: float = 0.0, ) -> None: self.model = model self.scheduler = scheduler @@ -164,6 +165,7 @@ class ImageParams: self.seed = seed self.steps = steps self.lpw = lpw or False + self.eta = eta def tojson(self) -> Dict[str, Optional[Param]]: return { @@ -175,6 +177,7 @@ class ImageParams: "seed": self.seed, "steps": self.steps, "lpw": self.lpw, + "eta": self.eta, } def with_args(self, **kwargs): @@ -187,6 +190,7 @@ class ImageParams: kwargs.get("seed", self.seed), kwargs.get("negative_prompt", self.negative_prompt), kwargs.get("lpw", self.lpw), + kwargs.get("eta", self.eta), ) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 38ef1981..6e7ced0c 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -172,6 +172,13 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: get_config_value("cfg", "max"), get_config_value("cfg", "min"), ) + eta = get_and_clamp_float( + request.args, + "eta", + get_config_value("eta"), + get_config_value("eta", "max"), + get_config_value("eta", "min"), + ) steps = get_and_clamp_int( request.args, "steps", @@ -220,6 +227,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: cfg, steps, seed, + eta=eta, lpw=lpw, negative_prompt=negative_prompt, ) diff --git a/api/params.json b/api/params.json index 48b8bdf8..7dda2d2c 100644 --- a/api/params.json +++ b/api/params.json @@ -22,6 +22,12 @@ "max": 1, "step": 0.1 }, + "eta": { + "default": 0.0, + "min": 0, + "max": 1, + "step": 0.1 + }, "faceOutscale": { "default": 1, "min": 1, diff --git a/gui/src/client.ts b/gui/src/client.ts index 5155e0f2..940dbcc3 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -45,6 +45,7 @@ export interface BaseImgParams { cfg: number; steps: number; seed: number; + eta: number; } /** @@ -279,6 +280,7 @@ export function makeApiUrl(root: string, ...path: Array) { export function makeImageURL(root: string, type: string, params: BaseImgParams): URL { const url = makeApiUrl(root, type); url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT)); + url.searchParams.append('eta', params.eta.toFixed(FIXED_FLOAT)); url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER)); if (doesExist(params.scheduler)) { diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index e1784031..6b80e552 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -34,23 +34,41 @@ export function ImageControl(props: ImageControlProps) { }); return - { - if (doesExist(props.onChange)) { - props.onChange({ - ...controlState, - scheduler: value, - }); - } - }} - /> + + { + if (doesExist(props.onChange)) { + props.onChange({ + ...controlState, + scheduler: value, + }); + } + }} + /> + { + if (doesExist(props.onChange)) { + props.onChange({ + ...controlState, + eta, + }); + } + }} + /> + { return { cfg: defaults.cfg.default, + eta: defaults.eta.default, negativePrompt: defaults.negativePrompt.default, prompt: defaults.prompt.default, scheduler: defaults.scheduler.default,