fix(api): only fetch diffusion models if they have not already been converted (#398)
This commit is contained in:
parent
c9b1df9fdd
commit
9c1fcd16fa
|
@ -257,14 +257,12 @@ def convert_model_diffusion(conversion: ConversionContext, model):
|
||||||
model["name"] = name
|
model["name"] = name
|
||||||
|
|
||||||
model_format = source_format(model)
|
model_format = source_format(model)
|
||||||
dest = fetch_model(conversion, name, model["source"], format=model_format)
|
|
||||||
|
|
||||||
pipeline = model.get("pipeline", "txt2img")
|
pipeline = model.get("pipeline", "txt2img")
|
||||||
converter = model_converters.get(pipeline)
|
converter = model_converters.get(pipeline)
|
||||||
converted, dest = converter(
|
converted, dest = converter(
|
||||||
conversion,
|
conversion,
|
||||||
model,
|
model,
|
||||||
dest,
|
|
||||||
model_format,
|
model_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,8 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||||
from ...models.cnet import UNet2DConditionModel_CNet
|
from ...models.cnet import UNet2DConditionModel_CNet
|
||||||
from ...utils import run_gc
|
from ...utils import run_gc
|
||||||
|
from ..client import fetch_model
|
||||||
|
from ..client.huggingface import HuggingfaceClient
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
RESOLVE_FORMATS,
|
RESOLVE_FORMATS,
|
||||||
ConversionContext,
|
ConversionContext,
|
||||||
|
@ -43,6 +45,7 @@ from ..utils import (
|
||||||
is_torch_2_0,
|
is_torch_2_0,
|
||||||
load_tensor,
|
load_tensor,
|
||||||
onnx_export,
|
onnx_export,
|
||||||
|
remove_prefix,
|
||||||
)
|
)
|
||||||
from .checkpoint import convert_extract_checkpoint
|
from .checkpoint import convert_extract_checkpoint
|
||||||
|
|
||||||
|
@ -267,14 +270,13 @@ def collate_cnet(cnet_path):
|
||||||
def convert_diffusion_diffusers(
|
def convert_diffusion_diffusers(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: Dict,
|
model: Dict,
|
||||||
source: str,
|
|
||||||
format: Optional[str],
|
format: Optional[str],
|
||||||
hf: bool = False,
|
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
|
source = model.get("source")
|
||||||
|
|
||||||
# optional
|
# optional
|
||||||
config = model.get("config", None)
|
config = model.get("config", None)
|
||||||
|
@ -320,9 +322,11 @@ def convert_diffusion_diffusers(
|
||||||
logger.info("ONNX model already exists, skipping")
|
logger.info("ONNX model already exists, skipping")
|
||||||
return (False, dest_path)
|
return (False, dest_path)
|
||||||
|
|
||||||
|
cache_path = fetch_model(conversion, name, source, format=format)
|
||||||
|
|
||||||
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
||||||
v2, pipe_args = get_model_version(
|
v2, pipe_args = get_model_version(
|
||||||
source, conversion.map_location, size=image_size, version=version
|
cache_path, conversion.map_location, size=image_size, version=version
|
||||||
)
|
)
|
||||||
|
|
||||||
is_inpainting = False
|
is_inpainting = False
|
||||||
|
@ -334,14 +338,15 @@ def convert_diffusion_diffusers(
|
||||||
pipe_args["from_safetensors"] = True
|
pipe_args["from_safetensors"] = True
|
||||||
|
|
||||||
torch_source = None
|
torch_source = None
|
||||||
if path.exists(source) and path.isdir(source):
|
if path.exists(cache_path):
|
||||||
|
if path.isdir(cache_path):
|
||||||
logger.debug("loading pipeline from diffusers directory: %s", source)
|
logger.debug("loading pipeline from diffusers directory: %s", source)
|
||||||
pipeline = pipe_class.from_pretrained(
|
pipeline = pipe_class.from_pretrained(
|
||||||
source,
|
cache_path,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
use_auth_token=conversion.token,
|
use_auth_token=conversion.token,
|
||||||
).to(device)
|
).to(device)
|
||||||
elif path.exists(source) and path.isfile(source):
|
elif path.isfile(source):
|
||||||
if conversion.extract:
|
if conversion.extract:
|
||||||
logger.debug("extracting SD checkpoint to Torch models: %s", source)
|
logger.debug("extracting SD checkpoint to Torch models: %s", source)
|
||||||
torch_source = convert_extract_checkpoint(
|
torch_source = convert_extract_checkpoint(
|
||||||
|
@ -352,7 +357,9 @@ def convert_diffusion_diffusers(
|
||||||
config_file=config,
|
config_file=config,
|
||||||
vae_file=replace_vae,
|
vae_file=replace_vae,
|
||||||
)
|
)
|
||||||
logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
|
logger.debug(
|
||||||
|
"loading pipeline from extracted checkpoint: %s", torch_source
|
||||||
|
)
|
||||||
pipeline = pipe_class.from_pretrained(
|
pipeline = pipe_class.from_pretrained(
|
||||||
torch_source,
|
torch_source,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
|
@ -368,16 +375,21 @@ def convert_diffusion_diffusers(
|
||||||
pipeline_class=pipe_class,
|
pipeline_class=pipe_class,
|
||||||
**pipe_args,
|
**pipe_args,
|
||||||
).to(device, torch_dtype=dtype)
|
).to(device, torch_dtype=dtype)
|
||||||
elif hf:
|
elif source.startswith(HuggingfaceClient.protocol):
|
||||||
logger.debug("downloading pretrained model from Huggingface hub: %s", source)
|
hf_path = remove_prefix(source, HuggingfaceClient.protocol)
|
||||||
|
logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path)
|
||||||
pipeline = pipe_class.from_pretrained(
|
pipeline = pipe_class.from_pretrained(
|
||||||
source,
|
hf_path,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
use_auth_token=conversion.token,
|
use_auth_token=conversion.token,
|
||||||
).to(device)
|
).to(device)
|
||||||
else:
|
else:
|
||||||
logger.warning("pipeline source not found or not recognized: %s", source)
|
logger.warning(
|
||||||
raise ValueError(f"pipeline source not found or not recognized: {source}")
|
"pipeline source not found and protocol not recognized: %s", source
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"pipeline source not found and protocol not recognized: {source}"
|
||||||
|
)
|
||||||
|
|
||||||
if replace_vae is not None:
|
if replace_vae is not None:
|
||||||
vae_path = path.join(conversion.model_path, replace_vae)
|
vae_path = path.join(conversion.model_path, replace_vae)
|
||||||
|
|
|
@ -10,6 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||||
from optimum.exporters.onnx import main_export
|
from optimum.exporters.onnx import main_export
|
||||||
|
|
||||||
from ...constants import ONNX_MODEL
|
from ...constants import ONNX_MODEL
|
||||||
|
from ..client import fetch_model
|
||||||
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
|
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -19,14 +20,13 @@ logger = getLogger(__name__)
|
||||||
def convert_diffusion_diffusers_xl(
|
def convert_diffusion_diffusers_xl(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: Dict,
|
model: Dict,
|
||||||
source: str,
|
|
||||||
format: Optional[str],
|
format: Optional[str],
|
||||||
hf: bool = False,
|
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
|
source = model.get("source")
|
||||||
replace_vae = model.get("vae", None)
|
replace_vae = model.get("vae", None)
|
||||||
|
|
||||||
device = conversion.training_device
|
device = conversion.training_device
|
||||||
|
@ -52,15 +52,16 @@ def convert_diffusion_diffusers_xl(
|
||||||
|
|
||||||
return (False, dest_path)
|
return (False, dest_path)
|
||||||
|
|
||||||
|
cache_path = fetch_model(conversion, name, model["source"], format=format)
|
||||||
# safetensors -> diffusers directory with torch models
|
# safetensors -> diffusers directory with torch models
|
||||||
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
||||||
|
|
||||||
if format == "safetensors":
|
if format == "safetensors":
|
||||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||||
source, use_safetensors=True
|
cache_path, use_safetensors=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
|
pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path)
|
||||||
|
|
||||||
if replace_vae is not None:
|
if replace_vae is not None:
|
||||||
vae_path = path.join(conversion.model_path, replace_vae)
|
vae_path = path.join(conversion.model_path, replace_vae)
|
||||||
|
|
Loading…
Reference in New Issue