1
0
Fork 0

test controlnet pipe

This commit is contained in:
Sean Sube 2023-04-11 23:06:32 -05:00
parent 113f21fd34
commit f49a4ddfdf
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 13 additions and 8 deletions

View File

@ -167,7 +167,9 @@ def convert_diffusion_diffusers(
del pipeline.unet del pipeline.unet
# CNet # CNet
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet") pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet").to(
device=device, dtype=dtype
)
cnet_path = output_path / "cnet" / ONNX_MODEL cnet_path = output_path / "cnet" / ONNX_MODEL
onnx_export( onnx_export(

View File

@ -122,6 +122,8 @@ def load_pipeline(
inversions = inversions or [] inversions = inversions or []
loras = loras or [] loras = loras or []
controlnet = "canny" # TODO; from params
torch_dtype = ( torch_dtype = (
torch.float16 if "torch-fp16" in server.optimizations else torch.float32 torch.float16 if "torch-fp16" in server.optimizations else torch.float32
) )
@ -276,6 +278,9 @@ def load_pipeline(
) )
) )
if controlnet is not None:
components["controlnet"] = OnnxRuntimeModel.from_pretrained(controlnet)
pipe = pipeline.from_pretrained( pipe = pipeline.from_pretrained(
model, model,
custom_pipeline=custom_pipeline, custom_pipeline=custom_pipeline,

View File

@ -86,7 +86,6 @@ class OnnxStableDiffusionControlNetPipeline(DiffusionPipeline):
controlnet=controlnet, controlnet=controlnet,
scheduler=scheduler, scheduler=scheduler,
) )
self.register_to_config(requires_safety_checker=False)
def _default_height_width(self, height, width, image): def _default_height_width(self, height, width, image):
if isinstance(image, list): if isinstance(image, list):
@ -486,6 +485,4 @@ class OnnxStableDiffusionControlNetPipeline(DiffusionPipeline):
if not return_dict: if not return_dict:
return (image, None) return (image, None)
return StableDiffusionPipelineOutput( return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
images=image, nsfw_content_detected=None
)

View File

@ -6,10 +6,10 @@ import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image from PIL import Image
from onnx_web.chain.utils import process_tile_order
from ..chain import blend_mask, upscale_outpaint from ..chain import blend_mask, upscale_outpaint
from ..chain.base import ChainProgress from ..chain.base import ChainProgress
from ..chain.utils import process_tile_order
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ..output import save_image, save_params from ..output import save_image, save_params
from ..params import ( from ..params import (
Border, Border,
@ -233,7 +233,8 @@ def run_img2img_pipeline(
pipe = load_pipeline( pipe = load_pipeline(
server, server,
OnnxStableDiffusionImg2ImgPipeline, # OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionControlNetPipeline,
params.model, params.model,
params.scheduler, params.scheduler,
job.get_device(), job.get_device(),