
fix(api): only fetch diffusion models if they have not already been converted (#398)

Sean Sube 2023-12-10 13:52:52 -06:00
parent c9b1df9fdd
commit 9c1fcd16fa
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 55 additions and 44 deletions
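
The gist of the fix: fetch_model used to run before the converted-model check, so every conversion pass re-downloaded or re-resolved remote sources. Moving the fetch below the early return means sources are only touched when a conversion will actually happen. A minimal sketch of the new control flow, using names from the diffs below (the dest_path construction and fetch_model itself are simplified stand-ins, not the real implementations):

from os import path

def convert_sketch(conversion, model, format=None):
    # hypothetical destination path; the real code derives it per-converter
    dest_path = path.join(conversion.model_path, f"{model['name']}.onnx")

    # already converted: return early without touching the network
    if path.exists(dest_path):
        return (False, dest_path)

    # only fetch the source once a conversion is actually needed
    cache_path = fetch_model(conversion, model["name"], model["source"], format=format)
    return (True, cache_path)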


@@ -257,14 +257,12 @@ def convert_model_diffusion(conversion: ConversionContext, model):
     model["name"] = name
     model_format = source_format(model)
-    dest = fetch_model(conversion, name, model["source"], format=model_format)
 
     pipeline = model.get("pipeline", "txt2img")
     converter = model_converters.get(pipeline)
     converted, dest = converter(
         conversion,
         model,
-        dest,
         model_format,
     )
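
The caller no longer fetches the source or passes a dest path down; each converter now receives just the context, the model record, and the format, and reports its own destination. The new contract, as inferred from this diff:

# converters resolve their own source from the model record
converted, dest = converter(
    conversion,    # ConversionContext
    model,         # dict with at least "name" and "source"
    model_format,  # e.g. "safetensors", from source_format(model)
)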


@@ -36,6 +36,8 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
 from ...diffusers.version_safe_diffusers import AttnProcessor
 from ...models.cnet import UNet2DConditionModel_CNet
 from ...utils import run_gc
+from ..client import fetch_model
+from ..client.huggingface import HuggingfaceClient
 from ..utils import (
     RESOLVE_FORMATS,
     ConversionContext,
@@ -43,6 +45,7 @@ from ..utils import (
     is_torch_2_0,
     load_tensor,
     onnx_export,
+    remove_prefix,
 )
 
 from .checkpoint import convert_extract_checkpoint
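
Two of these imports replace the old hf flag: hub sources are now recognized by a protocol prefix on the source string instead of a boolean threaded through the call chain. A minimal sketch of what the helpers presumably look like, inferred from how they are used further down (the real definitions live in the client module and may differ):

class HuggingfaceClient:
    # prefix that tags a source as a hub repo id, e.g. "huggingface://some-org/some-model"
    protocol = "huggingface://"

def remove_prefix(value: str, prefix: str) -> str:
    # strip the protocol prefix if present, leaving the bare repo id
    if value.startswith(prefix):
        return value[len(prefix):]
    return value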
@@ -267,14 +270,13 @@ def collate_cnet(cnet_path):
 def convert_diffusion_diffusers(
     conversion: ConversionContext,
     model: Dict,
-    source: str,
     format: Optional[str],
-    hf: bool = False,
 ) -> Tuple[bool, str]:
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
     name = model.get("name")
+    source = model.get("source")
 
     # optional
     config = model.get("config", None)
@@ -320,9 +322,11 @@ def convert_diffusion_diffusers(
         logger.info("ONNX model already exists, skipping")
         return (False, dest_path)
 
+    cache_path = fetch_model(conversion, name, source, format=format)
+
     pipe_class = CONVERT_PIPELINES.get(pipe_type)
     v2, pipe_args = get_model_version(
-        source, conversion.map_location, size=image_size, version=version
+        cache_path, conversion.map_location, size=image_size, version=version
     )
 
     is_inpainting = False
@@ -334,50 +338,58 @@ def convert_diffusion_diffusers(
         pipe_args["from_safetensors"] = True
 
     torch_source = None
-    if path.exists(source) and path.isdir(source):
-        logger.debug("loading pipeline from diffusers directory: %s", source)
-        pipeline = pipe_class.from_pretrained(
-            source,
-            torch_dtype=dtype,
-            use_auth_token=conversion.token,
-        ).to(device)
-    elif path.exists(source) and path.isfile(source):
-        if conversion.extract:
-            logger.debug("extracting SD checkpoint to Torch models: %s", source)
-            torch_source = convert_extract_checkpoint(
-                conversion,
-                source,
-                f"{name}-torch",
-                is_inpainting=is_inpainting,
-                config_file=config,
-                vae_file=replace_vae,
-            )
-            logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
-            pipeline = pipe_class.from_pretrained(
-                torch_source,
-                torch_dtype=dtype,
-            ).to(device)
-            # VAE replacement already happened during extraction, skip
-            replace_vae = None
-        else:
-            logger.debug("loading pipeline from SD checkpoint: %s", source)
-            pipeline = download_from_original_stable_diffusion_ckpt(
-                source,
-                original_config_file=config_path,
-                pipeline_class=pipe_class,
-                **pipe_args,
-            ).to(device, torch_dtype=dtype)
-    elif hf:
-        logger.debug("downloading pretrained model from Huggingface hub: %s", source)
+    if path.exists(cache_path):
+        if path.isdir(cache_path):
+            logger.debug("loading pipeline from diffusers directory: %s", source)
+            pipeline = pipe_class.from_pretrained(
+                cache_path,
+                torch_dtype=dtype,
+                use_auth_token=conversion.token,
+            ).to(device)
+        elif path.isfile(source):
+            if conversion.extract:
+                logger.debug("extracting SD checkpoint to Torch models: %s", source)
+                torch_source = convert_extract_checkpoint(
+                    conversion,
+                    source,
+                    f"{name}-torch",
+                    is_inpainting=is_inpainting,
+                    config_file=config,
+                    vae_file=replace_vae,
+                )
+                logger.debug(
+                    "loading pipeline from extracted checkpoint: %s", torch_source
+                )
+                pipeline = pipe_class.from_pretrained(
+                    torch_source,
+                    torch_dtype=dtype,
+                ).to(device)
+                # VAE replacement already happened during extraction, skip
+                replace_vae = None
+            else:
+                logger.debug("loading pipeline from SD checkpoint: %s", source)
+                pipeline = download_from_original_stable_diffusion_ckpt(
+                    source,
+                    original_config_file=config_path,
+                    pipeline_class=pipe_class,
+                    **pipe_args,
+                ).to(device, torch_dtype=dtype)
+    elif source.startswith(HuggingfaceClient.protocol):
+        hf_path = remove_prefix(source, HuggingfaceClient.protocol)
+        logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path)
         pipeline = pipe_class.from_pretrained(
-            source,
+            hf_path,
             torch_dtype=dtype,
             use_auth_token=conversion.token,
         ).to(device)
     else:
-        logger.warning("pipeline source not found or not recognized: %s", source)
-        raise ValueError(f"pipeline source not found or not recognized: {source}")
+        logger.warning(
+            "pipeline source not found and protocol not recognized: %s", source
+        )
+        raise ValueError(
+            f"pipeline source not found and protocol not recognized: {source}"
+        )
 
     if replace_vae is not None:
         vae_path = path.join(conversion.model_path, replace_vae)
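
Taken together, the rewritten chain chooses a loading strategy from the resolved cache path and the raw source string alone. A compressed restatement of the branches above, using the sketched helpers from earlier (paths and sources here are hypothetical):

from os import path

def pick_branch(source: str, cache_path: str) -> str:
    if path.exists(cache_path):
        # local hit: either a diffusers directory or a single checkpoint file
        return "directory" if path.isdir(cache_path) else "checkpoint"
    if source.startswith(HuggingfaceClient.protocol):
        # e.g. "huggingface://some-org/some-model" downloads from the hub
        return "hub:" + remove_prefix(source, HuggingfaceClient.protocol)
    raise ValueError(f"pipeline source not found and protocol not recognized: {source}")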


@@ -10,6 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
 from optimum.exporters.onnx import main_export
 
 from ...constants import ONNX_MODEL
+from ..client import fetch_model
 from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
 
 logger = getLogger(__name__)
@@ -19,14 +20,13 @@ logger = getLogger(__name__)
 def convert_diffusion_diffusers_xl(
     conversion: ConversionContext,
     model: Dict,
-    source: str,
     format: Optional[str],
-    hf: bool = False,
 ) -> Tuple[bool, str]:
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
     name = model.get("name")
+    source = model.get("source")
 
     replace_vae = model.get("vae", None)
     device = conversion.training_device
@@ -52,15 +52,16 @@ def convert_diffusion_diffusers_xl(
         return (False, dest_path)
 
+    cache_path = fetch_model(conversion, name, model["source"], format=format)
+
     # safetensors -> diffusers directory with torch models
     temp_path = path.join(conversion.cache_path, f"{name}-torch")
 
     if format == "safetensors":
         pipeline = StableDiffusionXLPipeline.from_single_file(
-            source, use_safetensors=True
+            cache_path, use_safetensors=True
         )
     else:
-        pipeline = StableDiffusionXLPipeline.from_pretrained(source)
+        pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path)
 
     if replace_vae is not None:
         vae_path = path.join(conversion.model_path, replace_vae)
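
The XL converter follows the same fetch-after-check pattern, then picks a loader from the declared format. A brief sketch of that dispatch, assuming cache_path points at either a single safetensors file or a diffusers directory:

from diffusers import StableDiffusionXLPipeline

def load_xl(cache_path: str, format: str) -> StableDiffusionXLPipeline:
    if format == "safetensors":
        # single-file checkpoint
        return StableDiffusionXLPipeline.from_single_file(cache_path, use_safetensors=True)
    # diffusers directory or hub snapshot
    return StableDiffusionXLPipeline.from_pretrained(cache_path)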