feat(api): add option for custom VAE to extras file (#145)
This commit is contained in:
parent
5e9dfa3465
commit
2c66cc58c5
|
@ -1144,7 +1144,8 @@ def extract_checkpoint(
|
|||
extract_ema=False,
|
||||
train_unfrozen=False,
|
||||
is_512=True,
|
||||
config_file=None,
|
||||
config_file: str =None,
|
||||
vae_file: str =None,
|
||||
):
|
||||
"""
|
||||
|
||||
|
@ -1306,7 +1307,12 @@ def extract_checkpoint(
|
|||
# Convert the VAE model.
|
||||
logger.info("converting VAE")
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
|
||||
if vae_file is None:
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
else:
|
||||
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
@ -1427,7 +1433,7 @@ def convert_diffusion_original(
|
|||
logger.info("torch pipeline already exists, reusing: %s", torch_path)
|
||||
else:
|
||||
logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
||||
extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"))
|
||||
extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"), vae_file=model.get("vae"))
|
||||
logger.info("converted original Diffusers checkpoint to Torch model")
|
||||
|
||||
convert_diffusion_stable(ctx, model, working_name)
|
||||
|
|
Loading…
Reference in New Issue