Compare commits
2 Commits
24039ab0ea
...
87f34cf369
Author | SHA1 | Date |
---|---|---|
Sean Sube | 87f34cf369 | |
Sean Sube | bea97efeb0 |
|
@ -61,7 +61,7 @@ model_converters: Dict[str, Any] = {
|
|||
"archive": convert_extract_archive,
|
||||
"img2img": convert_diffusion_diffusers_optimum,
|
||||
"img2img-sdxl": convert_diffusion_diffusers_xl,
|
||||
"inpaint": convert_diffusion_diffusers_legacy,
|
||||
"inpaint": convert_diffusion_diffusers_optimum,
|
||||
"txt2img": convert_diffusion_diffusers_optimum,
|
||||
"txt2img-legacy": convert_diffusion_diffusers_legacy,
|
||||
"txt2img-sdxl": convert_diffusion_diffusers_xl,
|
||||
|
|
|
@ -20,6 +20,7 @@ from diffusers import (
|
|||
AutoencoderKL,
|
||||
OnnxRuntimeModel,
|
||||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
|
@ -55,7 +56,7 @@ logger = getLogger(__name__)
|
|||
CONVERT_PIPELINES = {
|
||||
"controlnet": OnnxStableDiffusionControlNetPipeline,
|
||||
"img2img": StableDiffusionPipeline,
|
||||
"inpaint": StableDiffusionPipeline,
|
||||
"inpaint": StableDiffusionInpaintPipeline,
|
||||
"lpw": StableDiffusionPipeline,
|
||||
"panorama": StableDiffusionPipeline,
|
||||
"pix2pix": StableDiffusionInstructPix2PixPipeline,
|
||||
|
@ -853,6 +854,7 @@ def convert_diffusion_diffusers_optimum(
|
|||
"torch-fp16"
|
||||
), # optimum's fp16 mode only works on CUDA or ROCm
|
||||
framework="pt",
|
||||
library_name="diffusers",
|
||||
do_validation=conversion.has_feature("optimum-validation"),
|
||||
)
|
||||
|
||||
|
|
|
@ -15,15 +15,22 @@ def override_override_arguments(args, kwargs, signature, model_kwargs=None):
|
|||
"""
|
||||
Override the arguments of the `override_arguments` function.
|
||||
"""
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"overriding arguments for `override_arguments`: %s, %s, %s",
|
||||
args,
|
||||
kwargs,
|
||||
signature,
|
||||
)
|
||||
|
||||
# if "return_hidden_states" signature.parameters:
|
||||
# args[4] = True
|
||||
if "output_hidden_states" in signature.parameters:
|
||||
logger.debug("enabling hidden states for model")
|
||||
parameter_names = list(signature.parameters.keys())
|
||||
hidden_states_index = parameter_names.index("output_hidden_states")
|
||||
|
||||
# convert the arguments to a list for modification
|
||||
arg_list = list(args)
|
||||
arg_list[hidden_states_index] = True
|
||||
args = tuple(arg_list)
|
||||
|
||||
return original_override_arguments(args, kwargs, signature, model_kwargs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue