diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py index b1ccae69..602e698e 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion_original.py @@ -1144,7 +1144,8 @@ def extract_checkpoint( extract_ema=False, train_unfrozen=False, is_512=True, - config_file=None, + config_file: str =None, + vae_file: str =None, ): """ @@ -1306,7 +1307,12 @@ def extract_checkpoint( # Convert the VAE model. logger.info("converting VAE") vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if vae_file is None: + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + else: + vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu") + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config) vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) @@ -1427,7 +1433,7 @@ def convert_diffusion_original( logger.info("torch pipeline already exists, reusing: %s", torch_path) else: logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path) - extract_checkpoint(ctx, torch_name, source, config_file=model.get("config")) + extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"), vae_file=model.get("vae")) logger.info("converted original Diffusers checkpoint to Torch model") convert_diffusion_stable(ctx, model, working_name)