fix(api): load replacement VAE from single file for SD v1/v2
This commit is contained in:
parent
56f19256b5
commit
c99481f484
|
@ -341,7 +341,6 @@ def convert_diffusion_diffusers(
|
|||
source,
|
||||
original_config_file=config_path,
|
||||
pipeline_class=pipe_class,
|
||||
vae_path=replace_vae,
|
||||
**pipe_args,
|
||||
).to(device, torch_dtype=dtype)
|
||||
elif hf:
|
||||
|
@ -355,6 +354,13 @@ def convert_diffusion_diffusers(
|
|||
logger.warning("pipeline source not found or not recognized: %s", source)
|
||||
raise ValueError(f"pipeline source not found or not recognized: {source}")
|
||||
|
||||
if replace_vae is not None:
|
||||
vae_path = path.join(conversion.model_path, replace_vae)
|
||||
if replace_vae.endswith(".safetensors"):
|
||||
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
||||
else:
|
||||
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
|
||||
|
||||
optimize_pipeline(conversion, pipeline)
|
||||
|
||||
output_path = Path(dest_path)
|
||||
|
@ -507,19 +513,6 @@ def convert_diffusion_diffusers(
|
|||
del unet
|
||||
run_gc()
|
||||
|
||||
# VAE
|
||||
if replace_vae is not None:
|
||||
if replace_vae.startswith("."):
|
||||
logger.debug(
|
||||
"custom VAE appears to be a local path, making it relative to the model path"
|
||||
)
|
||||
replace_vae = path.join(conversion.model_path, replace_vae)
|
||||
|
||||
logger.info("loading custom VAE: %s", replace_vae)
|
||||
vae = AutoencoderKL.from_pretrained(replace_vae)
|
||||
pipeline.vae = vae
|
||||
run_gc()
|
||||
|
||||
if single_vae:
|
||||
logger.debug("VAE config: %s", pipeline.vae.config)
|
||||
|
||||
|
|
Loading…
Reference in New Issue