Compare commits
3 Commits
1e73eac68d
...
87f34cf369
Author | SHA1 | Date |
---|---|---|
Sean Sube | 87f34cf369 | |
Sean Sube | bea97efeb0 | |
Sean Sube | 24039ab0ea |
|
@ -61,7 +61,7 @@ model_converters: Dict[str, Any] = {
|
|||
"archive": convert_extract_archive,
|
||||
"img2img": convert_diffusion_diffusers_optimum,
|
||||
"img2img-sdxl": convert_diffusion_diffusers_xl,
|
||||
"inpaint": convert_diffusion_diffusers_legacy,
|
||||
"inpaint": convert_diffusion_diffusers_optimum,
|
||||
"txt2img": convert_diffusion_diffusers_optimum,
|
||||
"txt2img-legacy": convert_diffusion_diffusers_legacy,
|
||||
"txt2img-sdxl": convert_diffusion_diffusers_xl,
|
||||
|
|
|
@ -20,6 +20,7 @@ from diffusers import (
|
|||
AutoencoderKL,
|
||||
OnnxRuntimeModel,
|
||||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
|
@ -48,13 +49,14 @@ from ..utils import (
|
|||
remove_prefix,
|
||||
)
|
||||
from .checkpoint import convert_extract_checkpoint
|
||||
from .patches import patch_optimum
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
CONVERT_PIPELINES = {
|
||||
"controlnet": OnnxStableDiffusionControlNetPipeline,
|
||||
"img2img": StableDiffusionPipeline,
|
||||
"inpaint": StableDiffusionPipeline,
|
||||
"inpaint": StableDiffusionInpaintPipeline,
|
||||
"lpw": StableDiffusionPipeline,
|
||||
"panorama": StableDiffusionPipeline,
|
||||
"pix2pix": StableDiffusionInstructPix2PixPipeline,
|
||||
|
@ -841,6 +843,8 @@ def convert_diffusion_diffusers_optimum(
|
|||
del pipeline
|
||||
run_gc()
|
||||
|
||||
# patch Optimum for conversion and convert to ONNX
|
||||
patch_optimum()
|
||||
main_export(
|
||||
temp_path,
|
||||
output=dest_path,
|
||||
|
@ -850,6 +854,7 @@ def convert_diffusion_diffusers_optimum(
|
|||
"torch-fp16"
|
||||
), # optimum's fp16 mode only works on CUDA or ROCm
|
||||
framework="pt",
|
||||
library_name="diffusers",
|
||||
do_validation=conversion.has_feature("optimum-validation"),
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
"""
|
||||
Patches for optimum's internal conversion process.
|
||||
"""
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
from optimum.exporters.onnx import model_patcher
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
original_override_arguments = model_patcher.override_arguments
|
||||
|
||||
|
||||
def override_override_arguments(args, kwargs, signature, model_kwargs=None):
|
||||
"""
|
||||
Override the arguments of the `override_arguments` function.
|
||||
"""
|
||||
logger.debug(
|
||||
"overriding arguments for `override_arguments`: %s, %s, %s",
|
||||
args,
|
||||
kwargs,
|
||||
signature,
|
||||
)
|
||||
|
||||
if "output_hidden_states" in signature.parameters:
|
||||
logger.debug("enabling hidden states for model")
|
||||
parameter_names = list(signature.parameters.keys())
|
||||
hidden_states_index = parameter_names.index("output_hidden_states")
|
||||
|
||||
# convert the arguments to a list for modification
|
||||
arg_list = list(args)
|
||||
arg_list[hidden_states_index] = True
|
||||
args = tuple(arg_list)
|
||||
|
||||
return original_override_arguments(args, kwargs, signature, model_kwargs)
|
||||
|
||||
|
||||
def patch_optimum():
|
||||
logger.info("installing patches for optimum's internal conversion process")
|
||||
model_patcher.override_arguments = override_override_arguments
|
|
@ -76,6 +76,11 @@ def encode_prompt_compel(
|
|||
prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Text encoder patch for SD v1 and v2.
|
||||
|
||||
Using clip skip requires an ONNX model compiled with `return_hidden_states=True`.
|
||||
"""
|
||||
prompt, skip_clip_states = split_clip_skip(prompt)
|
||||
|
||||
embeddings_type = (
|
||||
|
|
Loading…
Reference in New Issue