1
0
Fork 0

Compare commits

...

2 Commits

3 changed files with 14 additions and 5 deletions

View File

@ -61,7 +61,7 @@ model_converters: Dict[str, Any] = {
"archive": convert_extract_archive,
"img2img": convert_diffusion_diffusers_optimum,
"img2img-sdxl": convert_diffusion_diffusers_xl,
"inpaint": convert_diffusion_diffusers_legacy,
"inpaint": convert_diffusion_diffusers_optimum,
"txt2img": convert_diffusion_diffusers_optimum,
"txt2img-legacy": convert_diffusion_diffusers_legacy,
"txt2img-sdxl": convert_diffusion_diffusers_xl,

View File

@ -20,6 +20,7 @@ from diffusers import (
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
@ -55,7 +56,7 @@ logger = getLogger(__name__)
CONVERT_PIPELINES = {
"controlnet": OnnxStableDiffusionControlNetPipeline,
"img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionPipeline,
"inpaint": StableDiffusionInpaintPipeline,
"lpw": StableDiffusionPipeline,
"panorama": StableDiffusionPipeline,
"pix2pix": StableDiffusionInstructPix2PixPipeline,
@ -853,6 +854,7 @@ def convert_diffusion_diffusers_optimum(
"torch-fp16"
), # optimum's fp16 mode only works on CUDA or ROCm
framework="pt",
library_name="diffusers",
do_validation=conversion.has_feature("optimum-validation"),
)

View File

@ -15,15 +15,22 @@ def override_override_arguments(args, kwargs, signature, model_kwargs=None):
"""
Override the arguments of the `override_arguments` function.
"""
logger.info(
logger.debug(
"overriding arguments for `override_arguments`: %s, %s, %s",
args,
kwargs,
signature,
)
# if "return_hidden_states" signature.parameters:
# args[4] = True
if "output_hidden_states" in signature.parameters:
logger.debug("enabling hidden states for model")
parameter_names = list(signature.parameters.keys())
hidden_states_index = parameter_names.index("output_hidden_states")
# convert the arguments to a list for modification
arg_list = list(args)
arg_list[hidden_states_index] = True
args = tuple(arg_list)
return original_override_arguments(args, kwargs, signature, model_kwargs)