test controlnet pipe
This commit is contained in:
parent
113f21fd34
commit
f49a4ddfdf
|
@ -167,7 +167,9 @@ def convert_diffusion_diffusers(
|
|||
del pipeline.unet
|
||||
|
||||
# CNet
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet").to(
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
|
||||
cnet_path = output_path / "cnet" / ONNX_MODEL
|
||||
onnx_export(
|
||||
|
|
|
@ -122,6 +122,8 @@ def load_pipeline(
|
|||
inversions = inversions or []
|
||||
loras = loras or []
|
||||
|
||||
controlnet = "canny" # TODO; from params
|
||||
|
||||
torch_dtype = (
|
||||
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
||||
)
|
||||
|
@ -276,6 +278,9 @@ def load_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
if controlnet is not None:
|
||||
components["controlnet"] = OnnxRuntimeModel.from_pretrained(controlnet)
|
||||
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
custom_pipeline=custom_pipeline,
|
||||
|
|
|
@ -86,7 +86,6 @@ class OnnxStableDiffusionControlNetPipeline(DiffusionPipeline):
|
|||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=False)
|
||||
|
||||
def _default_height_width(self, height, width, image):
|
||||
if isinstance(image, list):
|
||||
|
@ -486,6 +485,4 @@ class OnnxStableDiffusionControlNetPipeline(DiffusionPipeline):
|
|||
if not return_dict:
|
||||
return (image, None)
|
||||
|
||||
return StableDiffusionPipelineOutput(
|
||||
images=image, nsfw_content_detected=None
|
||||
)
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
|
||||
|
|
|
@ -6,10 +6,10 @@ import torch
|
|||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.utils import process_tile_order
|
||||
|
||||
from ..chain import blend_mask, upscale_outpaint
|
||||
from ..chain.base import ChainProgress
|
||||
from ..chain.utils import process_tile_order
|
||||
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||
from ..output import save_image, save_params
|
||||
from ..params import (
|
||||
Border,
|
||||
|
@ -233,7 +233,8 @@ def run_img2img_pipeline(
|
|||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
# OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionControlNetPipeline,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
|
|
Loading…
Reference in New Issue