feat(api): add option for custom VAE to extras file (#145)
This commit is contained in:
parent
5e9dfa3465
commit
2c66cc58c5
|
@ -1144,7 +1144,8 @@ def extract_checkpoint(
|
||||||
extract_ema=False,
|
extract_ema=False,
|
||||||
train_unfrozen=False,
|
train_unfrozen=False,
|
||||||
is_512=True,
|
is_512=True,
|
||||||
config_file=None,
|
config_file: str =None,
|
||||||
|
vae_file: str =None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -1306,7 +1307,12 @@ def extract_checkpoint(
|
||||||
# Convert the VAE model.
|
# Convert the VAE model.
|
||||||
logger.info("converting VAE")
|
logger.info("converting VAE")
|
||||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
|
||||||
|
if vae_file is None:
|
||||||
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||||
|
else:
|
||||||
|
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
|
||||||
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config)
|
||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
|
@ -1427,7 +1433,7 @@ def convert_diffusion_original(
|
||||||
logger.info("torch pipeline already exists, reusing: %s", torch_path)
|
logger.info("torch pipeline already exists, reusing: %s", torch_path)
|
||||||
else:
|
else:
|
||||||
logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
||||||
extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"))
|
extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"), vae_file=model.get("vae"))
|
||||||
logger.info("converted original Diffusers checkpoint to Torch model")
|
logger.info("converted original Diffusers checkpoint to Torch model")
|
||||||
|
|
||||||
convert_diffusion_stable(ctx, model, working_name)
|
convert_diffusion_stable(ctx, model, working_name)
|
||||||
|
|
Loading…
Reference in New Issue