1
0
Fork 0

fix(api): use Torch pipelines while loading models for conversion

This commit is contained in:
Sean Sube 2023-09-24 10:04:21 -05:00
parent c99481f484
commit 5d3a7d77a5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 22 additions and 3 deletions

View File

@ -16,14 +16,22 @@ from shutil import rmtree
from typing import Any, Dict, Optional, Tuple, Union
import torch
from diffusers import AutoencoderKL, OnnxRuntimeModel, OnnxStableDiffusionPipeline
from diffusers import (
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
from onnx import load_model, save_model
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import available_pipelines, optimize_pipeline
from ...diffusers.load import optimize_pipeline
from ...diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
@ -33,6 +41,17 @@ from .checkpoint import convert_extract_checkpoint
logger = getLogger(__name__)
CONVERT_PIPELINES = {
"controlnet": OnnxStableDiffusionControlNetPipeline,
"img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionPipeline,
"lpw": StableDiffusionPipeline,
"panorama": StableDiffusionPipeline,
"pix2pix": StableDiffusionInstructPix2PixPipeline,
"txt2img": StableDiffusionPipeline,
"upscale": StableDiffusionUpscalePipeline,
}
def get_model_version(
source,
@ -295,7 +314,7 @@ def convert_diffusion_diffusers(
logger.info("ONNX model already exists, skipping")
return (False, dest_path)
pipe_class = available_pipelines.get(pipe_type)
pipe_class = CONVERT_PIPELINES.get(pipe_type)
v2, pipe_args = get_model_version(
source, conversion.map_location, size=image_size, version=version
)