1
0
Fork 0

feat(api): add option for custom VAE to extras file (#145)

This commit is contained in:
Sean Sube 2023-02-16 18:53:50 -06:00
parent 5e9dfa3465
commit 2c66cc58c5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 9 additions and 3 deletions

View File

@ -1144,7 +1144,8 @@ def extract_checkpoint(
extract_ema=False,
train_unfrozen=False,
is_512=True,
config_file=None,
config_file: str =None,
vae_file: str =None,
):
"""
@ -1306,7 +1307,12 @@ def extract_checkpoint(
# Convert the VAE model.
logger.info("converting VAE")
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
if vae_file is None:
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
else:
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
@ -1427,7 +1433,7 @@ def convert_diffusion_original(
logger.info("torch pipeline already exists, reusing: %s", torch_path)
else:
logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"))
extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"), vae_file=model.get("vae"))
logger.info("converted original Diffusers checkpoint to Torch model")
convert_diffusion_stable(ctx, model, working_name)