diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index e651e7f5..27e72a63 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -32,8 +32,9 @@ def run_txt2img_pipeline( ) -> None: # TODO: add to params highres_scale = 4 - highres_steps = 15 - highres_strength = 0.5 + highres_steps = 25 + highres_strength = 0.2 + highres_steps_post = int((params.steps - highres_steps) / highres_strength) latents = get_latents_from_seed(params.seed, size, batch=params.batch) @@ -87,8 +88,9 @@ def run_txt2img_pipeline( for image, output in zip(result.images, outputs): if highres_scale > 1: + highres_progress = ChainProgress.from_progress(progress) # load img2img pipeline once - highpipe = load_pipeline( + highres_pipe = load_pipeline( server, OnnxStableDiffusionImg2ImgPipeline, params.model, @@ -104,32 +106,32 @@ def run_txt2img_pipeline( if params.lpw: logger.debug("using LPW pipeline for img2img") rng = torch.manual_seed(params.seed) - result = highpipe.img2img( + result = highres_pipe.img2img( tile, params.prompt, generator=rng, guidance_scale=params.cfg, negative_prompt=params.negative_prompt, num_images_per_prompt=1, - num_inference_steps=params.steps - highres_steps, + num_inference_steps=highres_steps_post, strength=highres_strength, eta=params.eta, - callback=progress, + callback=highres_progress, ) return result.images[0] else: rng = np.random.RandomState(params.seed) - result = highpipe( + result = highres_pipe( params.prompt, tile, generator=rng, guidance_scale=params.cfg, negative_prompt=params.negative_prompt, num_images_per_prompt=1, - num_inference_steps=params.steps - highres_steps, + num_inference_steps=highres_steps_post, strength=highres_strength, eta=params.eta, - callback=progress, + callback=highres_progress, ) return result.images[0]