feat(api): support custom VAE for diffusers models
This commit is contained in:
parent
388eb640c0
commit
d42de16a84
|
@ -1301,7 +1301,7 @@ def extract_checkpoint(
|
|||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
else:
|
||||
vae_file = os.path.join(ctx.model_path, vae_file)
|
||||
logger.debug("loading custom VAE file: %s", vae_file)
|
||||
logger.debug("loading custom VAE: %s", vae_file)
|
||||
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ from typing import Dict
|
|||
|
||||
import torch
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
OnnxRuntimeModel,
|
||||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionPipeline,
|
||||
|
@ -71,6 +72,7 @@ def convert_diffusion_stable(
|
|||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
single_vae = model.get("single_vae")
|
||||
replace_vae = model.get("vae")
|
||||
|
||||
dtype = torch.float16 if ctx.half else torch.float32
|
||||
dest_path = path.join(ctx.model_path, name)
|
||||
|
@ -177,6 +179,11 @@ def convert_diffusion_stable(
|
|||
)
|
||||
del pipeline.unet
|
||||
|
||||
if replace_vae is not None:
|
||||
logger.debug("loading custom VAE: %s", replace_vae)
|
||||
vae = AutoencoderKL.from_pretrained(replace_vae)
|
||||
pipeline.vae = vae
|
||||
|
||||
if single_vae:
|
||||
logger.debug("VAE config: %s", pipeline.vae.config)
|
||||
|
||||
|
|
Loading…
Reference in New Issue