diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py
index d51aa590..e9469226 100644
--- a/api/onnx_web/chain/blend_img2img.py
+++ b/api/onnx_web/chain/blend_img2img.py
@@ -31,7 +31,10 @@ def blend_img2img(
         params.model,
         params.scheduler,
         job.get_device(),
+        params.lpw,
     )
+    if params.lpw:
+        pipe = pipe.img2img
 
     rng = torch.manual_seed(params.seed)
 
diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py
index 7b23405a..c83f1ef6 100644
--- a/api/onnx_web/chain/blend_inpaint.py
+++ b/api/onnx_web/chain/blend_inpaint.py
@@ -64,7 +64,10 @@ def blend_inpaint(
         params.model,
         params.scheduler,
         job.get_device(),
+        params.lpw,
     )
+    if params.lpw:
+        pipe = pipe.inpaint
 
     latents = get_latents_from_seed(params.seed, size)
     rng = torch.manual_seed(params.seed)
diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py
index 9b2f8a0a..7a904a1e 100644
--- a/api/onnx_web/chain/source_txt2img.py
+++ b/api/onnx_web/chain/source_txt2img.py
@@ -32,8 +32,10 @@ def source_txt2img(
         )
 
     pipe = load_pipeline(
-        OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
+        OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw
     )
+    if params.lpw:
+        pipe = pipe.text2img
 
     latents = get_latents_from_seed(params.seed, size)
     rng = torch.manual_seed(params.seed)
diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py
index 602a7438..d81efc91 100644
--- a/api/onnx_web/chain/upscale_outpaint.py
+++ b/api/onnx_web/chain/upscale_outpaint.py
@@ -71,11 +71,14 @@ def upscale_outpaint(
         params.scheduler,
         job.get_device(),
+        params.lpw,
     )
+    if params.lpw:
+        pipe = pipe.inpaint
 
     latents = get_tile_latents(full_latents, dims)
     rng = torch.manual_seed(params.seed)
 
-    result = pipe.inpaint(
+    result = pipe(
         image,
         mask,
         prompt,
@@ -96,7 +99,7 @@
     margin_x = float(max(border.left, border.right))
     margin_y = float(max(border.top, border.bottom))
     overlap = min(margin_x / source_image.width, margin_y / source_image.height)
-    if overlap > 0 and border.left == border.right and border.top == border.bottom:
+    if border.left == border.right and border.top == border.bottom:
         logger.debug("outpainting with an even border, using spiral tiling")
         output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint], overlap=overlap)
     else:
diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py
index 0f6cd6eb..6a1ad9e3 100644
--- a/api/onnx_web/diffusion/load.py
+++ b/api/onnx_web/diffusion/load.py
@@ -47,13 +47,13 @@ def get_tile_latents(
 
 
 def load_pipeline(
-    pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams
+    pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams, lpw: bool
 ):
     global last_pipeline_instance
     global last_pipeline_scheduler
     global last_pipeline_options
 
-    options = (pipeline, model, device.provider)
+    options = (pipeline, model, device.device, device.provider, lpw)
     if last_pipeline_instance is not None and last_pipeline_options == options:
         logger.debug("reusing existing diffusion pipeline")
         pipe = last_pipeline_instance
@@ -63,6 +63,11 @@
         last_pipeline_scheduler = None
         run_gc()
 
+    if lpw:
+        custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
+    else:
+        custom_pipeline = None
+
     logger.debug("loading new diffusion pipeline from %s", model)
     scheduler = scheduler.from_pretrained(
         model,
@@ -72,7 +77,7 @@
     )
     pipe = pipeline.from_pretrained(
         model,
custom_pipeline="./onnx_web/diffusion/lpw_stable_diffusion_onnx.py", + custom_pipeline=custom_pipeline, provider=device.provider, provider_options=device.options, revision="onnx", diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 961bb3c1..0d5b1f59 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -25,14 +25,16 @@ def run_txt2img_pipeline( upscale: UpscaleParams, ) -> None: pipe = load_pipeline( - OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device() + OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device(), params.lpw ) + if params.lpw: + pipe = pipe.text2img latents = get_latents_from_seed(params.seed, size) rng = torch.manual_seed(params.seed) progress = job.get_progress_callback() - result = pipe.text2img( + result = pipe( params.prompt, height=size.height, width=size.width, @@ -72,12 +74,15 @@ def run_img2img_pipeline( params.model, params.scheduler, job.get_device(), + params.lpw ) + if params.lpw: + pipe = pipe.img2img rng = torch.manual_seed(params.seed) progress = job.get_progress_callback() - result = pipe.img2img( + result = pipe( source_image, params.prompt, generator=rng, diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 26a6363a..525fc292 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -86,10 +86,11 @@ class ImageParams: model: str, scheduler: Any, prompt: str, - negative_prompt: Optional[str], cfg: float, steps: int, seed: int, + negative_prompt: Optional[str] = None, + lpw: Optional[bool] = False, ) -> None: self.model = model self.scheduler = scheduler @@ -98,6 +99,7 @@ class ImageParams: self.cfg = cfg self.seed = seed self.steps = steps + self.lpw = lpw or False def tojson(self) -> Dict[str, Optional[Param]]: return { @@ -108,6 +110,7 @@ class ImageParams: "cfg": self.cfg, "seed": self.seed, "steps": self.steps, + "lpw": self.lpw, } diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 0fb3f200..16a0c1ac 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -171,6 +171,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: device = available_platforms[0] # pipeline stuff + lpw = get_not_empty(request.args, "lpw", "false") == "true" model = get_not_empty(request.args, "model", get_config_value("model")) model_path = get_model_path(model) scheduler = get_from_map( @@ -233,7 +234,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: ) params = ImageParams( - model_path, scheduler, prompt, negative_prompt, cfg, steps, seed + model_path, scheduler, prompt, negative_prompt, cfg, steps, seed, lpw=lpw ) size = Size(width, height) return (device, params, size)