1
0
Fork 0

feat(api): allow converting inpaint models (#356)

This commit is contained in:
Sean Sube 2023-04-27 23:41:04 -05:00
parent fbeeae692b
commit dbd9a186ae
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 23 additions and 5 deletions

View File

@ -20,6 +20,7 @@ from diffusers import (
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
)
from onnx import load_model, save_model
@ -183,6 +184,7 @@ def convert_diffusion_diffusers(
source = source or model.get("source")
single_vae = model.get("single_vae")
replace_vae = model.get("vae")
pipe_type = model.get("pipeline", "image")
device = conversion.training_device
dtype = conversion.torch_dtype()
@ -211,11 +213,21 @@ def convert_diffusion_diffusers(
logger.info("ONNX model already exists, skipping")
return (False, dest_path)
pipeline = StableDiffusionPipeline.from_pretrained(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
if pipe_type == "image":
pipeline = StableDiffusionPipeline.from_pretrained(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
elif pipe_type == "inpaint":
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
else:
raise ValueError(f"unknown pipeline type: {pipe_type}")
output_path = Path(dest_path)
optimize_pipeline(conversion, pipeline)

View File

@ -86,6 +86,12 @@ $defs:
type: array
items:
$ref: "#/$defs/lora_network"
pipeline:
type: string
enum: [
image,
inpaint
]
vae:
type: string