skip key removal for VAE tensors
This commit is contained in:
parent
9e479795fa
commit
cddbc87ca3
|
@ -634,14 +634,17 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||||
return new_checkpoint, has_ema
|
return new_checkpoint, has_ema
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
def convert_ldm_vae_checkpoint(checkpoint, config, use_key=True):
|
||||||
# extract state dict for VAE
|
# extract state dict for VAE
|
||||||
vae_state_dict = {}
|
vae_state_dict = {}
|
||||||
vae_key = "first_stage_model."
|
vae_key = "first_stage_model."
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
for key in keys:
|
for key in keys:
|
||||||
|
if use_key:
|
||||||
if key.startswith(vae_key):
|
if key.startswith(vae_key):
|
||||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||||
|
else:
|
||||||
|
vae_state_dict[vae_key] = checkpoint.get(key)
|
||||||
|
|
||||||
new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
|
new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
|
||||||
"encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
|
"encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
|
||||||
|
@ -1313,7 +1316,7 @@ def extract_checkpoint(
|
||||||
else:
|
else:
|
||||||
vae_file = os.path.join(ctx.model_path, vae_file)
|
vae_file = os.path.join(ctx.model_path, vae_file)
|
||||||
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
|
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False)
|
||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
|
|
Loading…
Reference in New Issue