From 2b8b59a39cf17112bbf031031e664ed2cb849e6b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 23 Dec 2023 21:42:36 -0600 Subject: [PATCH] make SD conversion more like SDXL --- api/onnx_web/convert/diffusion/diffusion.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index 6367a639..a32228ce 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -24,9 +24,6 @@ from diffusers import ( StableDiffusionPipeline, StableDiffusionUpscalePipeline, ) -from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( - download_from_original_stable_diffusion_ckpt, -) from onnx import load_model, save_model from ...constants import ONNX_MODEL, ONNX_WEIGHTS @@ -388,6 +385,7 @@ def convert_diffusion_diffusers( return (False, dest_path) cache_path = fetch_model(conversion, name, source, format=format) + temp_path = path.join(conversion.cache_path, f"{name}-torch") pipe_class = CONVERT_PIPELINES.get(pipe_type) v2, pipe_args = get_model_version( @@ -417,9 +415,9 @@ def convert_diffusion_diffusers( torch_source = convert_extract_checkpoint( conversion, cache_path, - f"{name}-torch", + temp_path, is_inpainting=is_inpainting, - config_file=config, + config_file=config_path, vae_file=replace_vae, ) logger.debug( @@ -434,10 +432,9 @@ def convert_diffusion_diffusers( replace_vae = None else: logger.debug("loading pipeline from SD checkpoint: %s", source) - pipeline = download_from_original_stable_diffusion_ckpt( + pipeline = pipe_class.from_single_file( cache_path, original_config_file=config_path, - pipeline_class=pipe_class, **pipe_args, ).to(device, torch_dtype=dtype) elif source.startswith(HuggingfaceClient.protocol):