1
0
Fork 0

feat(api): support custom VAE for diffusers models

This commit is contained in:
Sean Sube 2023-02-16 22:52:25 -06:00
parent 388eb640c0
commit d42de16a84
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 8 additions and 1 deletions

View File

@ -1301,7 +1301,7 @@ def extract_checkpoint(
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
else:
vae_file = os.path.join(ctx.model_path, vae_file)
logger.debug("loading custom VAE file: %s", vae_file)
logger.debug("loading custom VAE: %s", vae_file)
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)

View File

@ -17,6 +17,7 @@ from typing import Dict
import torch
from diffusers import (
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionPipeline,
@ -71,6 +72,7 @@ def convert_diffusion_stable(
name = model.get("name")
source = source or model.get("source")
single_vae = model.get("single_vae")
replace_vae = model.get("vae")
dtype = torch.float16 if ctx.half else torch.float32
dest_path = path.join(ctx.model_path, name)
@ -177,6 +179,11 @@ def convert_diffusion_stable(
)
del pipeline.unet
if replace_vae is not None:
logger.debug("loading custom VAE: %s", replace_vae)
vae = AutoencoderKL.from_pretrained(replace_vae)
pipeline.vae = vae
if single_vae:
logger.debug("VAE config: %s", pipeline.vae.config)