1
0
Fork 0

Compare commits

...

3 Commits

4 changed files with 52 additions and 2 deletions

View File

@ -61,7 +61,7 @@ model_converters: Dict[str, Any] = {
"archive": convert_extract_archive,
"img2img": convert_diffusion_diffusers_optimum,
"img2img-sdxl": convert_diffusion_diffusers_xl,
"inpaint": convert_diffusion_diffusers_legacy,
"inpaint": convert_diffusion_diffusers_optimum,
"txt2img": convert_diffusion_diffusers_optimum,
"txt2img-legacy": convert_diffusion_diffusers_legacy,
"txt2img-sdxl": convert_diffusion_diffusers_xl,

View File

@ -20,6 +20,7 @@ from diffusers import (
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
@ -48,13 +49,14 @@ from ..utils import (
remove_prefix,
)
from .checkpoint import convert_extract_checkpoint
from .patches import patch_optimum
logger = getLogger(__name__)
CONVERT_PIPELINES = {
"controlnet": OnnxStableDiffusionControlNetPipeline,
"img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionPipeline,
"inpaint": StableDiffusionInpaintPipeline,
"lpw": StableDiffusionPipeline,
"panorama": StableDiffusionPipeline,
"pix2pix": StableDiffusionInstructPix2PixPipeline,
@ -841,6 +843,8 @@ def convert_diffusion_diffusers_optimum(
del pipeline
run_gc()
# patch Optimum for conversion and convert to ONNX
patch_optimum()
main_export(
temp_path,
output=dest_path,
@ -850,6 +854,7 @@ def convert_diffusion_diffusers_optimum(
"torch-fp16"
), # optimum's fp16 mode only works on CUDA or ROCm
framework="pt",
library_name="diffusers",
do_validation=conversion.has_feature("optimum-validation"),
)

View File

@ -0,0 +1,40 @@
"""
Patches for optimum's internal conversion process.
"""
from logging import getLogger
from optimum.exporters.onnx import model_patcher
logger = getLogger(__name__)
original_override_arguments = model_patcher.override_arguments
def override_override_arguments(args, kwargs, signature, model_kwargs=None):
"""
Override the arguments of the `override_arguments` function.
"""
logger.debug(
"overriding arguments for `override_arguments`: %s, %s, %s",
args,
kwargs,
signature,
)
if "output_hidden_states" in signature.parameters:
logger.debug("enabling hidden states for model")
parameter_names = list(signature.parameters.keys())
hidden_states_index = parameter_names.index("output_hidden_states")
# convert the arguments to a list for modification
arg_list = list(args)
arg_list[hidden_states_index] = True
args = tuple(arg_list)
return original_override_arguments(args, kwargs, signature, model_kwargs)
def patch_optimum():
logger.info("installing patches for optimum's internal conversion process")
model_patcher.override_arguments = override_override_arguments

View File

@ -76,6 +76,11 @@ def encode_prompt_compel(
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Text encoder patch for SD v1 and v2.
Using clip skip requires an ONNX model compiled with `return_hidden_states=True`.
"""
prompt, skip_clip_states = split_clip_skip(prompt)
embeddings_type = (