diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py index 52d69960..b11d1458 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion_original.py @@ -1301,7 +1301,7 @@ def extract_checkpoint( converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) else: vae_file = os.path.join(ctx.model_path, vae_file) - logger.debug("loading custom VAE file: %s", vae_file) + logger.debug("loading custom VAE: %s", vae_file) vae_checkpoint = load_tensor(vae_file, map_location=map_location) converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False) diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion_stable.py index 135de80a..b6dec0a1 100644 --- a/api/onnx_web/convert/diffusion_stable.py +++ b/api/onnx_web/convert/diffusion_stable.py @@ -17,6 +17,7 @@ from typing import Dict import torch from diffusers import ( + AutoencoderKL, OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline, @@ -71,6 +72,7 @@ def convert_diffusion_stable( name = model.get("name") source = source or model.get("source") single_vae = model.get("single_vae") + replace_vae = model.get("vae") dtype = torch.float16 if ctx.half else torch.float32 dest_path = path.join(ctx.model_path, name) @@ -177,6 +179,11 @@ def convert_diffusion_stable( ) del pipeline.unet + if replace_vae is not None: + logger.debug("loading custom VAE: %s", replace_vae) + vae = AutoencoderKL.from_pretrained(replace_vae) + pipeline.vae = vae + if single_vae: logger.debug("VAE config: %s", pipeline.vae.config)