fix(api): only fetch diffusion models if they have not already been converted (#398)
This commit is contained in:
parent
c9b1df9fdd
commit
9c1fcd16fa
|
@ -257,14 +257,12 @@ def convert_model_diffusion(conversion: ConversionContext, model):
|
|||
model["name"] = name
|
||||
|
||||
model_format = source_format(model)
|
||||
dest = fetch_model(conversion, name, model["source"], format=model_format)
|
||||
|
||||
pipeline = model.get("pipeline", "txt2img")
|
||||
converter = model_converters.get(pipeline)
|
||||
converted, dest = converter(
|
||||
conversion,
|
||||
model,
|
||||
dest,
|
||||
model_format,
|
||||
)
|
||||
|
||||
|
|
|
@ -36,6 +36,8 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
|||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||
from ...models.cnet import UNet2DConditionModel_CNet
|
||||
from ...utils import run_gc
|
||||
from ..client import fetch_model
|
||||
from ..client.huggingface import HuggingfaceClient
|
||||
from ..utils import (
|
||||
RESOLVE_FORMATS,
|
||||
ConversionContext,
|
||||
|
@ -43,6 +45,7 @@ from ..utils import (
|
|||
is_torch_2_0,
|
||||
load_tensor,
|
||||
onnx_export,
|
||||
remove_prefix,
|
||||
)
|
||||
from .checkpoint import convert_extract_checkpoint
|
||||
|
||||
|
@ -267,14 +270,13 @@ def collate_cnet(cnet_path):
|
|||
def convert_diffusion_diffusers(
|
||||
conversion: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
||||
format: Optional[str],
|
||||
hf: bool = False,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
name = model.get("name")
|
||||
source = model.get("source")
|
||||
|
||||
# optional
|
||||
config = model.get("config", None)
|
||||
|
@ -320,9 +322,11 @@ def convert_diffusion_diffusers(
|
|||
logger.info("ONNX model already exists, skipping")
|
||||
return (False, dest_path)
|
||||
|
||||
cache_path = fetch_model(conversion, name, source, format=format)
|
||||
|
||||
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
||||
v2, pipe_args = get_model_version(
|
||||
source, conversion.map_location, size=image_size, version=version
|
||||
cache_path, conversion.map_location, size=image_size, version=version
|
||||
)
|
||||
|
||||
is_inpainting = False
|
||||
|
@ -334,50 +338,58 @@ def convert_diffusion_diffusers(
|
|||
pipe_args["from_safetensors"] = True
|
||||
|
||||
torch_source = None
|
||||
if path.exists(source) and path.isdir(source):
|
||||
logger.debug("loading pipeline from diffusers directory: %s", source)
|
||||
pipeline = pipe_class.from_pretrained(
|
||||
source,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(device)
|
||||
elif path.exists(source) and path.isfile(source):
|
||||
if conversion.extract:
|
||||
logger.debug("extracting SD checkpoint to Torch models: %s", source)
|
||||
torch_source = convert_extract_checkpoint(
|
||||
conversion,
|
||||
source,
|
||||
f"{name}-torch",
|
||||
is_inpainting=is_inpainting,
|
||||
config_file=config,
|
||||
vae_file=replace_vae,
|
||||
)
|
||||
logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
|
||||
if path.exists(cache_path):
|
||||
if path.isdir(cache_path):
|
||||
logger.debug("loading pipeline from diffusers directory: %s", source)
|
||||
pipeline = pipe_class.from_pretrained(
|
||||
torch_source,
|
||||
cache_path,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(device)
|
||||
elif path.isfile(source):
|
||||
if conversion.extract:
|
||||
logger.debug("extracting SD checkpoint to Torch models: %s", source)
|
||||
torch_source = convert_extract_checkpoint(
|
||||
conversion,
|
||||
source,
|
||||
f"{name}-torch",
|
||||
is_inpainting=is_inpainting,
|
||||
config_file=config,
|
||||
vae_file=replace_vae,
|
||||
)
|
||||
logger.debug(
|
||||
"loading pipeline from extracted checkpoint: %s", torch_source
|
||||
)
|
||||
pipeline = pipe_class.from_pretrained(
|
||||
torch_source,
|
||||
torch_dtype=dtype,
|
||||
).to(device)
|
||||
|
||||
# VAE replacement already happened during extraction, skip
|
||||
replace_vae = None
|
||||
else:
|
||||
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
||||
pipeline = download_from_original_stable_diffusion_ckpt(
|
||||
source,
|
||||
original_config_file=config_path,
|
||||
pipeline_class=pipe_class,
|
||||
**pipe_args,
|
||||
).to(device, torch_dtype=dtype)
|
||||
elif hf:
|
||||
logger.debug("downloading pretrained model from Huggingface hub: %s", source)
|
||||
# VAE replacement already happened during extraction, skip
|
||||
replace_vae = None
|
||||
else:
|
||||
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
||||
pipeline = download_from_original_stable_diffusion_ckpt(
|
||||
source,
|
||||
original_config_file=config_path,
|
||||
pipeline_class=pipe_class,
|
||||
**pipe_args,
|
||||
).to(device, torch_dtype=dtype)
|
||||
elif source.startswith(HuggingfaceClient.protocol):
|
||||
hf_path = remove_prefix(source, HuggingfaceClient.protocol)
|
||||
logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path)
|
||||
pipeline = pipe_class.from_pretrained(
|
||||
source,
|
||||
hf_path,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(device)
|
||||
else:
|
||||
logger.warning("pipeline source not found or not recognized: %s", source)
|
||||
raise ValueError(f"pipeline source not found or not recognized: {source}")
|
||||
logger.warning(
|
||||
"pipeline source not found and protocol not recognized: %s", source
|
||||
)
|
||||
raise ValueError(
|
||||
f"pipeline source not found and protocol not recognized: {source}"
|
||||
)
|
||||
|
||||
if replace_vae is not None:
|
||||
vae_path = path.join(conversion.model_path, replace_vae)
|
||||
|
|
|
@ -10,6 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
|
|||
from optimum.exporters.onnx import main_export
|
||||
|
||||
from ...constants import ONNX_MODEL
|
||||
from ..client import fetch_model
|
||||
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -19,14 +20,13 @@ logger = getLogger(__name__)
|
|||
def convert_diffusion_diffusers_xl(
|
||||
conversion: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
||||
format: Optional[str],
|
||||
hf: bool = False,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
name = model.get("name")
|
||||
source = model.get("source")
|
||||
replace_vae = model.get("vae", None)
|
||||
|
||||
device = conversion.training_device
|
||||
|
@ -52,15 +52,16 @@ def convert_diffusion_diffusers_xl(
|
|||
|
||||
return (False, dest_path)
|
||||
|
||||
cache_path = fetch_model(conversion, name, model["source"], format=format)
|
||||
# safetensors -> diffusers directory with torch models
|
||||
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
||||
|
||||
if format == "safetensors":
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
source, use_safetensors=True
|
||||
cache_path, use_safetensors=True
|
||||
)
|
||||
else:
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path)
|
||||
|
||||
if replace_vae is not None:
|
||||
vae_path = path.join(conversion.model_path, replace_vae)
|
||||
|
|
Loading…
Reference in New Issue