skip key removal for VAE tensors
parent 9e479795fa
commit cddbc87ca3
@@ -634,14 +634,17 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     return new_checkpoint, has_ema


-def convert_ldm_vae_checkpoint(checkpoint, config):
+def convert_ldm_vae_checkpoint(checkpoint, config, use_key=True):
     # extract state dict for VAE
     vae_state_dict = {}
     vae_key = "first_stage_model."
     keys = list(checkpoint.keys())
     for key in keys:
-        if key.startswith(vae_key):
-            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+        if use_key:
+            if key.startswith(vae_key):
+                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+        else:
+            vae_state_dict[key] = checkpoint.get(key)

     new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
                       "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
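With use_key=True (the default, used when the VAE lives inside a full LDM checkpoint) only tensors under the "first_stage_model." prefix are kept and the prefix is stripped; with use_key=False every tensor is copied through unchanged, which is what a standalone VAE file needs since its keys carry no such prefix. A minimal runnable sketch of the two modes, with illustrative key names and placeholder float values standing in for tensors:

# Minimal sketch of the two extraction modes; key names and values are illustrative.
def extract_vae(checkpoint, use_key=True):
    vae_state_dict = {}
    vae_key = "first_stage_model."
    for key in list(checkpoint.keys()):
        if use_key:
            # Full checkpoint: keep only VAE tensors and strip the prefix.
            if key.startswith(vae_key):
                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
        else:
            # Standalone VAE file: keys are already unprefixed, so keep them as-is.
            vae_state_dict[key] = checkpoint.get(key)
    return vae_state_dict

full_ckpt = {
    "first_stage_model.encoder.conv_in.weight": 1.0,        # VAE tensor in a full checkpoint
    "model.diffusion_model.input_blocks.0.0.weight": 2.0,   # UNet tensor, filtered out
}
standalone_vae = {"encoder.conv_in.weight": 1.0}            # no "first_stage_model." prefix

assert extract_vae(full_ckpt) == {"encoder.conv_in.weight": 1.0}
assert extract_vae(standalone_vae, use_key=False) == {"encoder.conv_in.weight": 1.0}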
@@ -1313,7 +1316,7 @@ def extract_checkpoint(
     else:
         vae_file = os.path.join(ctx.model_path, vae_file)
         vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
-        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config)
+        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False)

     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
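At the call site, a separately supplied VAE file is loaded with safetensors and converted with use_key=False: its tensors are stored without the "first_stage_model." prefix, so the prefix filter would discard all of them, but they still need the LDM-to-diffusers renaming that the rest of convert_ldm_vae_checkpoint performs. A usage sketch, assuming the patched convert_ldm_vae_checkpoint is importable; the file name and vae_config are placeholders:

# Usage sketch; "vae.safetensors" is a placeholder path, vae_config is assumed to be
# the AutoencoderKL constructor kwargs, and convert_ldm_vae_checkpoint is assumed
# to come from the patched module.
import safetensors.torch
from diffusers import AutoencoderKL

vae_checkpoint = safetensors.torch.load_file("vae.safetensors", device="cpu")
# use_key=False keeps every tensor; a standalone VAE carries no
# "first_stage_model." prefix, so nothing would survive the prefix filter.
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False)

vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)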