From c99481f4848d18028d2d6fcfb591fac08fa94016 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 24 Sep 2023 09:49:50 -0500 Subject: [PATCH] fix(api): load replacement VAE from single file for SD v1/v2 --- api/onnx_web/convert/diffusion/diffusion.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index 93bc96a1..a14f2c95 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -341,7 +341,6 @@ def convert_diffusion_diffusers( source, original_config_file=config_path, pipeline_class=pipe_class, - vae_path=replace_vae, **pipe_args, ).to(device, torch_dtype=dtype) elif hf: @@ -355,6 +354,13 @@ def convert_diffusion_diffusers( logger.warning("pipeline source not found or not recognized: %s", source) raise ValueError(f"pipeline source not found or not recognized: {source}") + if replace_vae is not None: + vae_path = path.join(conversion.model_path, replace_vae) + if replace_vae.endswith(".safetensors"): + pipeline.vae = AutoencoderKL.from_single_file(vae_path) + else: + pipeline.vae = AutoencoderKL.from_pretrained(vae_path) + optimize_pipeline(conversion, pipeline) output_path = Path(dest_path) @@ -507,19 +513,6 @@ def convert_diffusion_diffusers( del unet run_gc() - # VAE - if replace_vae is not None: - if replace_vae.startswith("."): - logger.debug( - "custom VAE appears to be a local path, making it relative to the model path" - ) - replace_vae = path.join(conversion.model_path, replace_vae) - - logger.info("loading custom VAE: %s", replace_vae) - vae = AutoencoderKL.from_pretrained(replace_vae) - pipeline.vae = vae - run_gc() - if single_vae: logger.debug("VAE config: %s", pipeline.vae.config)