feat(api): support custom VAE for diffusers models
This commit is contained in:
parent
388eb640c0
commit
d42de16a84
|
@ -1301,7 +1301,7 @@ def extract_checkpoint(
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||||
else:
|
else:
|
||||||
vae_file = os.path.join(ctx.model_path, vae_file)
|
vae_file = os.path.join(ctx.model_path, vae_file)
|
||||||
logger.debug("loading custom VAE file: %s", vae_file)
|
logger.debug("loading custom VAE: %s", vae_file)
|
||||||
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ from typing import Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
|
AutoencoderKL,
|
||||||
OnnxRuntimeModel,
|
OnnxRuntimeModel,
|
||||||
OnnxStableDiffusionPipeline,
|
OnnxStableDiffusionPipeline,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
|
@ -71,6 +72,7 @@ def convert_diffusion_stable(
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
source = source or model.get("source")
|
source = source or model.get("source")
|
||||||
single_vae = model.get("single_vae")
|
single_vae = model.get("single_vae")
|
||||||
|
replace_vae = model.get("vae")
|
||||||
|
|
||||||
dtype = torch.float16 if ctx.half else torch.float32
|
dtype = torch.float16 if ctx.half else torch.float32
|
||||||
dest_path = path.join(ctx.model_path, name)
|
dest_path = path.join(ctx.model_path, name)
|
||||||
|
@ -177,6 +179,11 @@ def convert_diffusion_stable(
|
||||||
)
|
)
|
||||||
del pipeline.unet
|
del pipeline.unet
|
||||||
|
|
||||||
|
if replace_vae is not None:
|
||||||
|
logger.debug("loading custom VAE: %s", replace_vae)
|
||||||
|
vae = AutoencoderKL.from_pretrained(replace_vae)
|
||||||
|
pipeline.vae = vae
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
logger.debug("VAE config: %s", pipeline.vae.config)
|
logger.debug("VAE config: %s", pipeline.vae.config)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue