test controlnet pipe
This commit is contained in:
parent
113f21fd34
commit
f49a4ddfdf
|
@ -167,7 +167,9 @@ def convert_diffusion_diffusers(
|
||||||
del pipeline.unet
|
del pipeline.unet
|
||||||
|
|
||||||
# CNet
|
# CNet
|
||||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
|
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet").to(
|
||||||
|
device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
cnet_path = output_path / "cnet" / ONNX_MODEL
|
cnet_path = output_path / "cnet" / ONNX_MODEL
|
||||||
onnx_export(
|
onnx_export(
|
||||||
|
|
|
@ -122,6 +122,8 @@ def load_pipeline(
|
||||||
inversions = inversions or []
|
inversions = inversions or []
|
||||||
loras = loras or []
|
loras = loras or []
|
||||||
|
|
||||||
|
controlnet = "canny" # TODO; from params
|
||||||
|
|
||||||
torch_dtype = (
|
torch_dtype = (
|
||||||
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
||||||
)
|
)
|
||||||
|
@ -276,6 +278,9 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if controlnet is not None:
|
||||||
|
components["controlnet"] = OnnxRuntimeModel.from_pretrained(controlnet)
|
||||||
|
|
||||||
pipe = pipeline.from_pretrained(
|
pipe = pipeline.from_pretrained(
|
||||||
model,
|
model,
|
||||||
custom_pipeline=custom_pipeline,
|
custom_pipeline=custom_pipeline,
|
||||||
|
|
|
@ -86,7 +86,6 @@ class OnnxStableDiffusionControlNetPipeline(DiffusionPipeline):
|
||||||
controlnet=controlnet,
|
controlnet=controlnet,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
self.register_to_config(requires_safety_checker=False)
|
|
||||||
|
|
||||||
def _default_height_width(self, height, width, image):
|
def _default_height_width(self, height, width, image):
|
||||||
if isinstance(image, list):
|
if isinstance(image, list):
|
||||||
|
@ -486,6 +485,4 @@ class OnnxStableDiffusionControlNetPipeline(DiffusionPipeline):
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (image, None)
|
return (image, None)
|
||||||
|
|
||||||
return StableDiffusionPipelineOutput(
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
|
||||||
images=image, nsfw_content_detected=None
|
|
||||||
)
|
|
||||||
|
|
|
@ -6,10 +6,10 @@ import torch
|
||||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.utils import process_tile_order
|
|
||||||
|
|
||||||
from ..chain import blend_mask, upscale_outpaint
|
from ..chain import blend_mask, upscale_outpaint
|
||||||
from ..chain.base import ChainProgress
|
from ..chain.base import ChainProgress
|
||||||
|
from ..chain.utils import process_tile_order
|
||||||
|
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||||
from ..output import save_image, save_params
|
from ..output import save_image, save_params
|
||||||
from ..params import (
|
from ..params import (
|
||||||
Border,
|
Border,
|
||||||
|
@ -233,7 +233,8 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
server,
|
server,
|
||||||
OnnxStableDiffusionImg2ImgPipeline,
|
# OnnxStableDiffusionImg2ImgPipeline,
|
||||||
|
OnnxStableDiffusionControlNetPipeline,
|
||||||
params.model,
|
params.model,
|
||||||
params.scheduler,
|
params.scheduler,
|
||||||
job.get_device(),
|
job.get_device(),
|
||||||
|
|
Loading…
Reference in New Issue