1
0
Fork 0

add conversion patch for text encoder

This commit is contained in:
Sean Sube 2024-03-03 14:42:13 -06:00
parent 1e73eac68d
commit 24039ab0ea
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 41 additions and 0 deletions

View File

@ -48,6 +48,7 @@ from ..utils import (
remove_prefix,
)
from .checkpoint import convert_extract_checkpoint
from .patches import patch_optimum
logger = getLogger(__name__)
@ -841,6 +842,8 @@ def convert_diffusion_diffusers_optimum(
del pipeline
run_gc()
# patch Optimum for conversion and convert to ONNX
patch_optimum()
main_export(
temp_path,
output=dest_path,

View File

@ -0,0 +1,33 @@
"""
Patches for optimum's internal conversion process.
"""
from logging import getLogger
from optimum.exporters.onnx import model_patcher
logger = getLogger(__name__)
original_override_arguments = model_patcher.override_arguments
def override_override_arguments(args, kwargs, signature, model_kwargs=None):
"""
Override the arguments of the `override_arguments` function.
"""
logger.info(
"overriding arguments for `override_arguments`: %s, %s, %s",
args,
kwargs,
signature,
)
# if "return_hidden_states" signature.parameters:
# args[4] = True
return original_override_arguments(args, kwargs, signature, model_kwargs)
def patch_optimum():
logger.info("installing patches for optimum's internal conversion process")
model_patcher.override_arguments = override_override_arguments

View File

@ -76,6 +76,11 @@ def encode_prompt_compel(
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Text encoder patch for SD v1 and v2.
Using clip skip requires an ONNX model compiled with `return_hidden_states=True`.
"""
prompt, skip_clip_states = split_clip_skip(prompt)
embeddings_type = (