From 7b506cb6d335ee89c1c0c405ab4ac48ef5b1f769 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 4 Feb 2023 17:44:54 -0600 Subject: [PATCH] feat(api): enable LPW custom pipeline (#27) --- api/onnx_web/chain/upscale_outpaint.py | 6 +++--- api/onnx_web/diffusion/load.py | 2 ++ api/onnx_web/diffusion/run.py | 6 +++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index a61d2ab5..4f6ecbb5 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -75,14 +75,14 @@ def upscale_outpaint( latents = get_tile_latents(full_latents, dims) rng = np.random.RandomState(params.seed) - result = pipe( + result = pipe.inpaint( + image, + mask, prompt, generator=rng, guidance_scale=params.cfg, height=size.height, - image=image, latents=latents, - mask_image=mask, negative_prompt=params.negative_prompt, num_inference_steps=params.steps, width=size.width, diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 00c0bf2e..3a530bd0 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -73,6 +73,8 @@ def load_pipeline( model, provider=device.provider, provider_options=device.options, + custom_pipeline='lpw_stable_diffusion_onnx', + revision='onnx', safety_checker=None, scheduler=scheduler, ) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 00ad3479..5f40faa4 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -32,7 +32,7 @@ def run_txt2img_pipeline( rng = np.random.RandomState(params.seed) progress = job.get_progress_callback() - result = pipe( + result = pipe.txt2img( params.prompt, height=size.height, width=size.width, @@ -77,11 +77,11 @@ def run_img2img_pipeline( rng = np.random.RandomState(params.seed) progress = job.get_progress_callback() - result = pipe( + result = pipe.img2img( + source_image, params.prompt, generator=rng, guidance_scale=params.cfg, - image=source_image, negative_prompt=params.negative_prompt, num_inference_steps=params.steps, strength=strength,