load checkpoint properly
This commit is contained in:
parent
ca1b22d44d
commit
873276f1d0
|
@ -1302,6 +1302,7 @@ def extract_checkpoint(
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||||
else:
|
else:
|
||||||
vae_file = os.path.join(ctx.model_path, vae_file)
|
vae_file = os.path.join(ctx.model_path, vae_file)
|
||||||
|
logger.debug("loading custom VAE file: %s", vae_file)
|
||||||
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
|
||||||
|
|
||||||
|
|
|
@ -221,3 +221,5 @@ def load_tensor(name: str, map_location=None):
|
||||||
checkpoint = (
|
checkpoint = (
|
||||||
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return checkpoint
|
Loading…
Reference in New Issue