add conversion patch for text encoder
This commit is contained in:
parent
1e73eac68d
commit
24039ab0ea
|
@ -48,6 +48,7 @@ from ..utils import (
|
|||
remove_prefix,
|
||||
)
|
||||
from .checkpoint import convert_extract_checkpoint
|
||||
from .patches import patch_optimum
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -841,6 +842,8 @@ def convert_diffusion_diffusers_optimum(
|
|||
del pipeline
|
||||
run_gc()
|
||||
|
||||
# patch Optimum for conversion and convert to ONNX
|
||||
patch_optimum()
|
||||
main_export(
|
||||
temp_path,
|
||||
output=dest_path,
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
"""
|
||||
Patches for optimum's internal conversion process.
|
||||
"""
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
from optimum.exporters.onnx import model_patcher
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
original_override_arguments = model_patcher.override_arguments
|
||||
|
||||
|
||||
def override_override_arguments(args, kwargs, signature, model_kwargs=None):
|
||||
"""
|
||||
Override the arguments of the `override_arguments` function.
|
||||
"""
|
||||
logger.info(
|
||||
"overriding arguments for `override_arguments`: %s, %s, %s",
|
||||
args,
|
||||
kwargs,
|
||||
signature,
|
||||
)
|
||||
|
||||
# if "return_hidden_states" signature.parameters:
|
||||
# args[4] = True
|
||||
|
||||
return original_override_arguments(args, kwargs, signature, model_kwargs)
|
||||
|
||||
|
||||
def patch_optimum():
|
||||
logger.info("installing patches for optimum's internal conversion process")
|
||||
model_patcher.override_arguments = override_override_arguments
|
|
@ -76,6 +76,11 @@ def encode_prompt_compel(
|
|||
prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Text encoder patch for SD v1 and v2.
|
||||
|
||||
Using clip skip requires an ONNX model compiled with `return_hidden_states=True`.
|
||||
"""
|
||||
prompt, skip_clip_states = split_clip_skip(prompt)
|
||||
|
||||
embeddings_type = (
|
||||
|
|
Loading…
Reference in New Issue