From 9c1fcd16fa31318b85aa5e42175c93e4421d1f89 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 10 Dec 2023 13:52:52 -0600
Subject: [PATCH] fix(api): only fetch diffusion models if they have not
 already been converted (#398)

---
 api/onnx_web/convert/__main__.py             |  2 -
 api/onnx_web/convert/diffusion/diffusion.py  | 88 +++++++++++--------
 .../convert/diffusion/diffusion_xl.py        |  9 +-
 3 files changed, 55 insertions(+), 44 deletions(-)

diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index 34be1420..1a4f0241 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -257,14 +257,12 @@ def convert_model_diffusion(conversion: ConversionContext, model):
     model["name"] = name
 
     model_format = source_format(model)
-    dest = fetch_model(conversion, name, model["source"], format=model_format)
 
     pipeline = model.get("pipeline", "txt2img")
     converter = model_converters.get(pipeline)
     converted, dest = converter(
         conversion,
         model,
-        dest,
         model_format,
     )
 
diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py
index dbb50610..aea560d9 100644
--- a/api/onnx_web/convert/diffusion/diffusion.py
+++ b/api/onnx_web/convert/diffusion/diffusion.py
@@ -36,6 +36,8 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
 from ...diffusers.version_safe_diffusers import AttnProcessor
 from ...models.cnet import UNet2DConditionModel_CNet
 from ...utils import run_gc
+from ..client import fetch_model
+from ..client.huggingface import HuggingfaceClient
 from ..utils import (
     RESOLVE_FORMATS,
     ConversionContext,
@@ -43,6 +45,7 @@ from ..utils import (
     is_torch_2_0,
     load_tensor,
     onnx_export,
+    remove_prefix,
 )
 from .checkpoint import convert_extract_checkpoint
 
@@ -267,14 +270,13 @@ def collate_cnet(cnet_path):
 def convert_diffusion_diffusers(
     conversion: ConversionContext,
     model: Dict,
-    source: str,
     format: Optional[str],
-    hf: bool = False,
 ) -> Tuple[bool, str]:
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
     name = model.get("name")
+    source = model.get("source")
 
     # optional
     config = model.get("config", None)
@@ -320,9 +322,11 @@ def convert_diffusion_diffusers(
         logger.info("ONNX model already exists, skipping")
         return (False, dest_path)
 
+    cache_path = fetch_model(conversion, name, source, format=format)
+
     pipe_class = CONVERT_PIPELINES.get(pipe_type)
     v2, pipe_args = get_model_version(
-        source, conversion.map_location, size=image_size, version=version
+        cache_path, conversion.map_location, size=image_size, version=version
     )
 
     is_inpainting = False
@@ -334,50 +338,58 @@ def convert_diffusion_diffusers(
         pipe_args["from_safetensors"] = True
 
     torch_source = None
-    if path.exists(source) and path.isdir(source):
-        logger.debug("loading pipeline from diffusers directory: %s", source)
-        pipeline = pipe_class.from_pretrained(
-            source,
-            torch_dtype=dtype,
-            use_auth_token=conversion.token,
-        ).to(device)
-    elif path.exists(source) and path.isfile(source):
-        if conversion.extract:
-            logger.debug("extracting SD checkpoint to Torch models: %s", source)
-            torch_source = convert_extract_checkpoint(
-                conversion,
-                source,
-                f"{name}-torch",
-                is_inpainting=is_inpainting,
-                config_file=config,
-                vae_file=replace_vae,
-            )
-            logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
+    if path.exists(cache_path):
+        if path.isdir(cache_path):
+            logger.debug("loading pipeline from diffusers directory: %s", source)
             pipeline = pipe_class.from_pretrained(
-                torch_source,
+                cache_path,
                 torch_dtype=dtype,
+                use_auth_token=conversion.token,
             ).to(device)
+        elif path.isfile(source):
+            if conversion.extract:
+                logger.debug("extracting SD checkpoint to Torch models: %s", source)
+                torch_source = convert_extract_checkpoint(
+                    conversion,
+                    source,
+                    f"{name}-torch",
+                    is_inpainting=is_inpainting,
+                    config_file=config,
+                    vae_file=replace_vae,
+                )
+                logger.debug(
+                    "loading pipeline from extracted checkpoint: %s", torch_source
+                )
+                pipeline = pipe_class.from_pretrained(
+                    torch_source,
+                    torch_dtype=dtype,
+                ).to(device)
 
-        # VAE replacement already happened during extraction, skip
-        replace_vae = None
-    else:
-        logger.debug("loading pipeline from SD checkpoint: %s", source)
-        pipeline = download_from_original_stable_diffusion_ckpt(
-            source,
-            original_config_file=config_path,
-            pipeline_class=pipe_class,
-            **pipe_args,
-        ).to(device, torch_dtype=dtype)
-    elif hf:
-        logger.debug("downloading pretrained model from Huggingface hub: %s", source)
+            # VAE replacement already happened during extraction, skip
+            replace_vae = None
+        else:
+            logger.debug("loading pipeline from SD checkpoint: %s", source)
+            pipeline = download_from_original_stable_diffusion_ckpt(
+                source,
+                original_config_file=config_path,
+                pipeline_class=pipe_class,
+                **pipe_args,
+            ).to(device, torch_dtype=dtype)
+    elif source.startswith(HuggingfaceClient.protocol):
+        hf_path = remove_prefix(source, HuggingfaceClient.protocol)
+        logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path)
         pipeline = pipe_class.from_pretrained(
-            source,
+            hf_path,
             torch_dtype=dtype,
             use_auth_token=conversion.token,
         ).to(device)
     else:
-        logger.warning("pipeline source not found or not recognized: %s", source)
-        raise ValueError(f"pipeline source not found or not recognized: {source}")
+        logger.warning(
+            "pipeline source not found and protocol not recognized: %s", source
+        )
+        raise ValueError(
+            f"pipeline source not found and protocol not recognized: {source}"
+        )
 
     if replace_vae is not None:
         vae_path = path.join(conversion.model_path, replace_vae)
diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py
index 8370d302..d9319596 100644
--- a/api/onnx_web/convert/diffusion/diffusion_xl.py
+++ b/api/onnx_web/convert/diffusion/diffusion_xl.py
@@ -10,6 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
 from optimum.exporters.onnx import main_export
 
 from ...constants import ONNX_MODEL
+from ..client import fetch_model
 from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
 
 logger = getLogger(__name__)
@@ -19,14 +20,13 @@
 def convert_diffusion_diffusers_xl(
     conversion: ConversionContext,
     model: Dict,
-    source: str,
     format: Optional[str],
-    hf: bool = False,
 ) -> Tuple[bool, str]:
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
     name = model.get("name")
+    source = model.get("source")
     replace_vae = model.get("vae", None)
 
     device = conversion.training_device
@@ -52,15 +52,16 @@ def convert_diffusion_diffusers_xl(
         return (False, dest_path)
 
+    cache_path = fetch_model(conversion, name, model["source"], format=format)
     # safetensors -> diffusers directory with torch models
     temp_path = path.join(conversion.cache_path, f"{name}-torch")
     if format == "safetensors":
         pipeline = StableDiffusionXLPipeline.from_single_file(
-            cache_path, use_safetensors=True
+            cache_path, use_safetensors=True
         )
     else:
-        pipeline = StableDiffusionXLPipeline.from_pretrained(source)
+        pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path)
 
     if replace_vae is not None:
         vae_path = path.join(conversion.model_path, replace_vae)
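
Net effect of the patch: fetch_model moves out of convert_model_diffusion and into the two converters, where it now runs only after the early return for an already-converted ONNX model, so models that have already been converted no longer re-download their original weights. Because the converters now resolve their own sources, convert_diffusion_diffusers also drops the old hf flag and detects Huggingface sources by protocol prefix instead, via source.startswith(HuggingfaceClient.protocol) and remove_prefix.

Below is a minimal, self-contained sketch of the check-then-fetch ordering the patch enforces. The names and signatures are simplified stand-ins, not onnx-web's real API: the real fetch_model takes a ConversionContext and a format hint, and the real converters go on to export the whole diffusers pipeline to ONNX.

    from os import path
    from typing import Tuple


    def fetch_model(name: str, source: str) -> str:
        # Hypothetical stub for onnx-web's fetch_model: download the original
        # weights (or reuse a cached copy) and return the local path.
        print(f"fetching {source} for {name}")
        return path.join("/tmp/model-cache", name)


    def convert_model(model_dir: str, name: str, source: str) -> Tuple[bool, str]:
        dest_path = path.join(model_dir, name)

        # Check for the converted ONNX output first. Before this patch, the
        # caller ran fetch_model unconditionally, so even fully converted
        # models could re-download their original weights.
        if path.exists(dest_path):
            print("ONNX model already exists, skipping")
            return (False, dest_path)

        # Only fetch once we know a conversion will actually happen.
        cache_path = fetch_model(name, source)
        # ... load the pipeline from cache_path and export it to dest_path ...
        return (True, dest_path)

The same ordering appears in both converters: convert_diffusion_diffusers checks dest_path before calling fetch_model and get_model_version, and convert_diffusion_diffusers_xl checks dest_path before fetching and loading the StableDiffusionXLPipeline.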