
skip key removal for VAE tensors

Sean Sube 2023-02-16 19:17:00 -06:00
parent 9e479795fa
commit cddbc87ca3
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 7 additions and 4 deletions


@@ -634,14 +634,17 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     return new_checkpoint, has_ema
 
 
-def convert_ldm_vae_checkpoint(checkpoint, config):
+def convert_ldm_vae_checkpoint(checkpoint, config, use_key=True):
     # extract state dict for VAE
     vae_state_dict = {}
     vae_key = "first_stage_model."
     keys = list(checkpoint.keys())
     for key in keys:
-        if key.startswith(vae_key):
-            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+        if use_key:
+            if key.startswith(vae_key):
+                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+        else:
+            vae_state_dict[vae_key] = checkpoint.get(key)
 
     new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
                       "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
@@ -1313,7 +1316,7 @@ def extract_checkpoint(
     else:
         vae_file = os.path.join(ctx.model_path, vae_file)
         vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
-        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config)
+        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False)
 
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
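For reference, a minimal sketch of the new call path when converting a standalone VAE file, assuming convert_ldm_vae_checkpoint and a matching vae_config are already in scope as they are inside extract_checkpoint; the file name is a hypothetical example.

# Hypothetical usage sketch; the path below and the surrounding
# vae_config are assumptions, not values from the repository.
import safetensors.torch

vae_file = "standalone_vae.safetensors"  # example: a VAE-only checkpoint
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")

# A VAE-only file carries no "first_stage_model." prefix on its keys,
# so key removal is skipped by passing use_key=False.
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
    vae_checkpoint, vae_config, use_key=False
)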