fix(api): use Torch pipelines while loading models for conversion
This commit is contained in:
parent
c99481f484
commit
5d3a7d77a5
|
@ -16,14 +16,22 @@ from shutil import rmtree
|
|||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from diffusers import AutoencoderKL, OnnxRuntimeModel, OnnxStableDiffusionPipeline
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
OnnxRuntimeModel,
|
||||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
from onnx import load_model, save_model
|
||||
|
||||
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
from ...diffusers.load import available_pipelines, optimize_pipeline
|
||||
from ...diffusers.load import optimize_pipeline
|
||||
from ...diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||
from ...models.cnet import UNet2DConditionModel_CNet
|
||||
|
@ -33,6 +41,17 @@ from .checkpoint import convert_extract_checkpoint
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
CONVERT_PIPELINES = {
|
||||
"controlnet": OnnxStableDiffusionControlNetPipeline,
|
||||
"img2img": StableDiffusionPipeline,
|
||||
"inpaint": StableDiffusionPipeline,
|
||||
"lpw": StableDiffusionPipeline,
|
||||
"panorama": StableDiffusionPipeline,
|
||||
"pix2pix": StableDiffusionInstructPix2PixPipeline,
|
||||
"txt2img": StableDiffusionPipeline,
|
||||
"upscale": StableDiffusionUpscalePipeline,
|
||||
}
|
||||
|
||||
|
||||
def get_model_version(
|
||||
source,
|
||||
|
@ -295,7 +314,7 @@ def convert_diffusion_diffusers(
|
|||
logger.info("ONNX model already exists, skipping")
|
||||
return (False, dest_path)
|
||||
|
||||
pipe_class = available_pipelines.get(pipe_type)
|
||||
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
||||
v2, pipe_args = get_model_version(
|
||||
source, conversion.map_location, size=image_size, version=version
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue