load checkpoint properly

Author: Sean Sube
Date: 2023-02-16 20:23:10 -06:00
Parent: ca1b22d44d
Commit: 873276f1d0
Signed by: ssube (GPG Key ID: 3EED7B957D362AF1)
2 changed files with 3 additions and 0 deletions

@@ -1302,6 +1302,7 @@ def extract_checkpoint(
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
     else:
         vae_file = os.path.join(ctx.model_path, vae_file)
+        logger.debug("loading custom VAE file: %s", vae_file)
         vae_checkpoint = load_tensor(vae_file, map_location=map_location)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
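
For context, a hedged sketch of the custom-VAE branch this hunk sits in: when a separate VAE file is configured, it is resolved against the conversion context's model path, loaded with load_tensor, and converted with convert_ldm_vae_checkpoint. The standalone wrapper below is illustrative only; its name and parameters are not part of the repository.

import logging
import os

logger = logging.getLogger(__name__)


def convert_custom_vae(model_path, vae_file, vae_config, load_tensor, convert_ldm_vae_checkpoint, map_location=None):
    # Resolve the custom VAE relative to the conversion context's model path.
    vae_file = os.path.join(model_path, vae_file)
    logger.debug("loading custom VAE file: %s", vae_file)

    # load_tensor must return the checkpoint's state dict (see the second hunk);
    # otherwise the conversion below receives None.
    vae_checkpoint = load_tensor(vae_file, map_location=map_location)
    return convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)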

@@ -221,3 +221,5 @@ def load_tensor(name: str, map_location=None):
     checkpoint = (
         checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
     )
+
+    return checkpoint
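
The second hunk fixes the bug named in the commit message: load_tensor unwrapped the checkpoint's state_dict but never returned it, so callers received None. Below is a minimal sketch of the helper under that assumption, covering only the torch.load path implied by the diff; the real onnx-web helper may handle additional formats.

import torch


def load_tensor(name: str, map_location=None):
    # Load the raw checkpoint from disk.
    checkpoint = torch.load(name, map_location=map_location)

    # Some checkpoints nest the weights under a "state_dict" key; unwrap if present.
    checkpoint = (
        checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    )

    # The return added by this commit: without it, callers received None.
    return checkpoint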