feat(api): allow converting inpaint models (#356)
This commit is contained in:
parent
fbeeae692b
commit
dbd9a186ae
|
@ -20,6 +20,7 @@ from diffusers import (
|
|||
AutoencoderKL,
|
||||
OnnxRuntimeModel,
|
||||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from onnx import load_model, save_model
|
||||
|
@ -183,6 +184,7 @@ def convert_diffusion_diffusers(
|
|||
source = source or model.get("source")
|
||||
single_vae = model.get("single_vae")
|
||||
replace_vae = model.get("vae")
|
||||
pipe_type = model.get("pipeline", "image")
|
||||
|
||||
device = conversion.training_device
|
||||
dtype = conversion.torch_dtype()
|
||||
|
@ -211,11 +213,21 @@ def convert_diffusion_diffusers(
|
|||
logger.info("ONNX model already exists, skipping")
|
||||
return (False, dest_path)
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
source,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(device)
|
||||
if pipe_type == "image":
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
source,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(device)
|
||||
elif pipe_type == "inpaint":
|
||||
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
source,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(device)
|
||||
else:
|
||||
raise ValueError(f"unknown pipeline type: {pipe_type}")
|
||||
|
||||
output_path = Path(dest_path)
|
||||
|
||||
optimize_pipeline(conversion, pipeline)
|
||||
|
|
|
@ -86,6 +86,12 @@ $defs:
|
|||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/lora_network"
|
||||
pipeline:
|
||||
type: string
|
||||
enum: [
|
||||
image,
|
||||
inpaint
|
||||
]
|
||||
vae:
|
||||
type: string
|
||||
|
||||
|
|
Loading…
Reference in New Issue