1
0
Fork 0

fix(api): load replacement VAE from single file for SD v1/v2

This commit is contained in:
Sean Sube 2023-09-24 09:49:50 -05:00
parent 56f19256b5
commit c99481f484
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 7 additions and 14 deletions

View File

@ -341,7 +341,6 @@ def convert_diffusion_diffusers(
source,
original_config_file=config_path,
pipeline_class=pipe_class,
vae_path=replace_vae,
**pipe_args,
).to(device, torch_dtype=dtype)
elif hf:
@ -355,6 +354,13 @@ def convert_diffusion_diffusers(
logger.warning("pipeline source not found or not recognized: %s", source)
raise ValueError(f"pipeline source not found or not recognized: {source}")
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if replace_vae.endswith(".safetensors"):
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
optimize_pipeline(conversion, pipeline)
output_path = Path(dest_path)
@ -507,19 +513,6 @@ def convert_diffusion_diffusers(
del unet
run_gc()
# VAE
if replace_vae is not None:
if replace_vae.startswith("."):
logger.debug(
"custom VAE appears to be a local path, making it relative to the model path"
)
replace_vae = path.join(conversion.model_path, replace_vae)
logger.info("loading custom VAE: %s", replace_vae)
vae = AutoencoderKL.from_pretrained(replace_vae)
pipeline.vae = vae
run_gc()
if single_vae:
logger.debug("VAE config: %s", pipeline.vae.config)