diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py index 4fae214b..42a818f0 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion_original.py @@ -634,14 +634,17 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False return new_checkpoint, has_ema -def convert_ldm_vae_checkpoint(checkpoint, config): +def convert_ldm_vae_checkpoint(checkpoint, config, use_key=True): # extract state dict for VAE vae_state_dict = {} vae_key = "first_stage_model." keys = list(checkpoint.keys()) for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + if use_key: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + else: + vae_state_dict[vae_key] = checkpoint.get(key) new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"], "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"], @@ -1313,7 +1316,7 @@ def extract_checkpoint( else: vae_file = os.path.join(ctx.model_path, vae_file) vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu") - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False) vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint)