diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index a14f2c95..6fe9b172 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -16,14 +16,22 @@ from shutil import rmtree from typing import Any, Dict, Optional, Tuple, Union import torch -from diffusers import AutoencoderKL, OnnxRuntimeModel, OnnxStableDiffusionPipeline +from diffusers import ( + AutoencoderKL, + OnnxRuntimeModel, + OnnxStableDiffusionPipeline, + StableDiffusionInstructPix2PixPipeline, + StableDiffusionPipeline, + StableDiffusionUpscalePipeline, +) from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( download_from_original_stable_diffusion_ckpt, ) from onnx import load_model, save_model from ...constants import ONNX_MODEL, ONNX_WEIGHTS -from ...diffusers.load import available_pipelines, optimize_pipeline +from ...diffusers.load import optimize_pipeline +from ...diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ...diffusers.version_safe_diffusers import AttnProcessor from ...models.cnet import UNet2DConditionModel_CNet @@ -33,6 +41,17 @@ from .checkpoint import convert_extract_checkpoint logger = getLogger(__name__) +CONVERT_PIPELINES = { + "controlnet": OnnxStableDiffusionControlNetPipeline, + "img2img": StableDiffusionPipeline, + "inpaint": StableDiffusionPipeline, + "lpw": StableDiffusionPipeline, + "panorama": StableDiffusionPipeline, + "pix2pix": StableDiffusionInstructPix2PixPipeline, + "txt2img": StableDiffusionPipeline, + "upscale": StableDiffusionUpscalePipeline, +} + def get_model_version( source, @@ -295,7 +314,7 @@ def convert_diffusion_diffusers( logger.info("ONNX model already exists, skipping") return (False, dest_path) - pipe_class = available_pipelines.get(pipe_type) + pipe_class = CONVERT_PIPELINES.get(pipe_type) v2, pipe_args = get_model_version( source, conversion.map_location, size=image_size, version=version )