From 2c8ba67ecb6881299514f721d5d5b28ded74b466 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 21 Aug 2023 20:46:05 -0500 Subject: [PATCH] make highres work with SDXL and use a full-size txt2img stage --- api/onnx_web/chain/blend_img2img.py | 4 ++-- api/onnx_web/diffusers/run.py | 2 +- api/onnx_web/params.py | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 9e6b428b..34986d52 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -57,7 +57,7 @@ class BlendImg2ImgStage(BaseStage): pipe_params["strength"] = strength elif params.is_panorama(): pipe_params["strength"] = strength - elif pipe_type == "img2img": + elif pipe_type == "img2img" or pipe_type == "img2img-sdxl": pipe_params["strength"] = strength elif pipe_type == "pix2pix": pipe_params["image_guidance_scale"] = strength @@ -83,7 +83,7 @@ class BlendImg2ImgStage(BaseStage): pipe, prompt_pairs, params.batch, params.do_cfg() ) - if not params.xl(): + if not params.is_xl(): pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 27a37130..5b1aae78 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -43,7 +43,7 @@ def run_txt2img_pipeline( highres: HighresParams, ) -> None: # if using panorama, the pipeline will tile itself (views) - if params.is_panorama(): + if params.is_panorama() or params.is_xl(): tile_size = max(params.tiles, size.width, size.height) else: tile_size = params.tiles diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index d8bd3364..08cfb159 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -261,6 +261,8 @@ class ImageParams: if group == "img2img": if pipeline in ["controlnet", "img2img-sdxl", "lpw", "panorama", "pix2pix"]: return pipeline + elif pipeline == "txt2img-sdxl": + return "img2img-sdxl" elif group == "inpaint": if pipeline in ["controlnet", "lpw", "panorama"]: return pipeline