From 44004730ea6976c8898a1ab3a5738cc51ef796df Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 20 Aug 2023 22:28:08 -0500 Subject: [PATCH] skip loading component models for XL --- api/onnx_web/chain/blend_img2img.py | 16 +++++++++------- api/onnx_web/chain/source_txt2img.py | 6 ++++-- api/onnx_web/chain/upscale_outpaint.py | 2 +- api/onnx_web/diffusers/load.py | 23 ++++++++++++++--------- api/onnx_web/diffusers/run.py | 3 +-- api/onnx_web/params.py | 11 ++++++++++- 6 files changed, 39 insertions(+), 22 deletions(-) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 572ac3d3..9e6b428b 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -51,20 +51,20 @@ class BlendImg2ImgStage(BaseStage): ) pipe_params = {} - if pipe_type == "controlnet": + if params.is_control(): pipe_params["controlnet_conditioning_scale"] = strength + elif params.is_lpw(): + pipe_params["strength"] = strength + elif params.is_panorama(): + pipe_params["strength"] = strength elif pipe_type == "img2img": pipe_params["strength"] = strength - elif pipe_type == "lpw": - pipe_params["strength"] = strength - elif pipe_type == "panorama": - pipe_params["strength"] = strength elif pipe_type == "pix2pix": pipe_params["image_guidance_scale"] = strength outputs = [] for source in sources: - if params.lpw(): + if params.is_lpw(): logger.debug("using LPW pipeline for img2img") rng = torch.manual_seed(params.seed) result = pipe.img2img( @@ -82,7 +82,9 @@ class BlendImg2ImgStage(BaseStage): prompt_embeds = encode_prompt( pipe, prompt_pairs, params.batch, params.do_cfg() ) - pipe.unet.set_prompts(prompt_embeds) + + if not params.is_xl(): + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index b865e2e7..d32f7342 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ 
b/api/onnx_web/chain/source_txt2img.py @@ -74,7 +74,7 @@ class SourceTxt2ImgStage(BaseStage): loras=loras, ) - if params.lpw(): + if params.is_lpw(): logger.debug("using LPW pipeline for txt2img") rng = torch.manual_seed(params.seed) result = pipe.text2img( @@ -95,7 +95,9 @@ class SourceTxt2ImgStage(BaseStage): prompt_embeds = encode_prompt( pipe, prompt_pairs, params.batch, params.do_cfg() ) - pipe.unet.set_prompts(prompt_embeds) + + if not params.is_xl(): + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 21b8e544..71de2629 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -81,7 +81,7 @@ class UpscaleOutpaintStage(BaseStage): else: latents = get_tile_latents(latents, params.seed, latent_size, dims) - if params.lpw(): + if params.is_lpw(): logger.debug("using LPW pipeline for inpaint") rng = torch.manual_seed(params.seed) result = pipe.inpaint( diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 8722056f..a564af70 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -171,7 +171,7 @@ def load_pipeline( unet_type = "unet" # ControlNet component - if pipeline == "controlnet" and params.control is not None: + if params.is_control() and params.control is not None: cnet_path = path.join( server.model_path, "control", f"{params.control.name}.onnx" ) @@ -282,7 +282,7 @@ def load_pipeline( ) # make sure a UNet has been loaded - if "unet" not in components: + if not params.is_xl() and "unet" not in components: unet = path.join(model, unet_type, ONNX_MODEL) logger.debug("loading UNet (%s) from %s", unet_type, unet) components["unet"] = OnnxRuntimeModel( @@ -298,7 +298,7 @@ def load_pipeline( vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL) vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL) - if 
path.exists(vae): + if not params.is_xl() and path.exists(vae): logger.debug("loading VAE from %s", vae) components["vae"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( @@ -307,7 +307,9 @@ def load_pipeline( sess_options=device.sess_options(), ) ) - elif path.exists(vae_decoder) and path.exists(vae_encoder): + elif ( + not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder) + ): logger.debug("loading VAE decoder from %s", vae_decoder) components["vae_decoder"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( @@ -327,7 +329,7 @@ def load_pipeline( ) # additional options for panorama pipeline - if pipeline == "panorama": + if params.is_panorama(): components["window"] = params.tiles // 8 components["stride"] = params.stride // 8 @@ -346,18 +348,21 @@ def load_pipeline( pipe.set_progress_bar_config(disable=True) optimize_pipeline(server, pipe) - patch_pipeline(server, pipe, pipeline, pipeline_class, params) + + if not params.is_xl(): + patch_pipeline(server, pipe, pipeline, pipeline_class, params) server.cache.set(ModelTypes.diffusion, pipe_key, pipe) server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"]) - if hasattr(pipe, "vae_decoder"): + if not params.is_xl() and hasattr(pipe, "vae_decoder"): pipe.vae_decoder.set_tiled(tiled=params.tiled_vae) - if hasattr(pipe, "vae_encoder"): + + if not params.is_xl() and hasattr(pipe, "vae_encoder"): pipe.vae_encoder.set_tiled(tiled=params.tiled_vae) # update panorama params - if pipeline == "panorama": + if params.is_panorama(): latent_window = params.tiles // 8 latent_stride = params.stride // 8 diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 0b601b67..27a37130 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -43,8 +43,7 @@ def run_txt2img_pipeline( highres: HighresParams, ) -> None: # if using panorama, the pipeline will tile itself (views) - pipe_type = params.get_valid_pipeline("txt2img") - if pipe_type == 
"panorama": + if params.is_panorama(): tile_size = max(params.tiles, size.width, size.height) else: tile_size = params.tiles diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 15a104fc..138e4a0e 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -271,9 +271,18 @@ class ImageParams: logger.debug("pipeline %s is not valid for %s", pipeline, group) return group - def lpw(self): + def is_control(self): + return self.pipeline == "controlnet" + + def is_lpw(self): return self.pipeline == "lpw" + def is_panorama(self): + return self.pipeline == "panorama" + + def is_xl(self): + return self.pipeline.endswith("-sdxl") + def tojson(self) -> Dict[str, Optional[Param]]: return { "model": self.model,