From 2f6a3afddb152727bee2a91a0359f05c331fe194 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 12 Feb 2023 12:33:36 -0600 Subject: [PATCH] pass progress on to most stages --- api/onnx_web/chain/base.py | 23 ++++++++++++++++++++--- api/onnx_web/diffusion/load.py | 1 + api/onnx_web/diffusion/run.py | 11 +++++++---- api/onnx_web/upscale.py | 17 +++++++++++++---- 4 files changed, 41 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index b48e1e54..1ad6a1de 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -82,7 +82,8 @@ class ChainPipeline: TODO: handle List[Image] inputs and outputs """ if callback is not None: - callback = ChainProgress(callback, start=callback.step) + start = callback.step if hasattr(callback, "step") else 0 + callback = ChainProgress(callback, start=start) start = monotonic() logger.info( @@ -115,7 +116,15 @@ class ChainPipeline: ) def stage_tile(tile: Image.Image, _dims) -> Image.Image: - tile = stage_pipe(job, server, stage_params, params, tile, **kwargs) + tile = stage_pipe( + job, + server, + stage_params, + params, + tile, + callback=callback, + **kwargs + ) if is_debug(): save_image(server, "last-tile.png", tile) @@ -131,7 +140,15 @@ class ChainPipeline: ) else: logger.info("image within tile size, running stage") - image = stage_pipe(job, server, stage_params, params, image, **kwargs) + image = stage_pipe( + job, + server, + stage_params, + params, + image, + callback=callback, + **kwargs + ) logger.info( "finished stage %s, result size: %sx%s", name, image.width, image.height diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 476d56e1..eb9d5e14 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -51,6 +51,7 @@ pipeline_schedulers = { "pndm": PNDMScheduler, } + def get_scheduler_name(scheduler: Any) -> Optional[str]: for k, v in pipeline_schedulers.items(): if scheduler == v or scheduler == v.__name__: diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 1c96bfd7..6c69eabe 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -154,7 +154,7 @@ def run_inpaint_pipeline( tile_order: str, ) -> None: # device = job.get_device() - # progress = job.get_progress_callback() + progress = job.get_progress_callback() stage = StageParams(tile_order=tile_order) image = upscale_outpaint( @@ -168,6 +168,7 @@ def run_inpaint_pipeline( fill_color=fill_color, mask_filter=mask_filter, noise_source=noise_source, + callback=progress, ) logger.info("applying mask filter and generating noise source") @@ -176,7 +177,9 @@ def run_inpaint_pipeline( else: logger.info("output image size does not match source, skipping post-blend") - image = run_upscale_correction(job, server, stage, params, image, upscale=upscale) + image = run_upscale_correction( + job, server, stage, params, image, upscale=upscale, callback=progress + ) dest = save_image(server, output, image) save_params(server, output, params, size, upscale=upscale, border=border) @@ -197,11 +200,11 @@ def run_upscale_pipeline( source_image: Image.Image, ) -> None: # device = job.get_device() - # progress = job.get_progress_callback() + progress = job.get_progress_callback() stage = StageParams() image = run_upscale_correction( - job, server, stage, params, source_image, upscale=upscale + job, server, stage, params, source_image, upscale=upscale, callback=progress ) dest = save_image(server, output, image) diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index f28724b6..d13ba7fd 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -9,7 +9,7 @@ from .chain import ( upscale_resrgan, upscale_stable_diffusion, ) -from .device_pool import JobContext +from .device_pool import JobContext, ProgressCallback from .params import ImageParams, SizeChart, StageParams, UpscaleParams from .utils import ServerContext @@ -24,6 +24,7 @@ def run_upscale_correction( image: Image.Image, *, upscale: UpscaleParams, + callback: ProgressCallback = None, ) -> Image.Image: """ This is a convenience method for a chain pipeline that will run upscaling and @@ -35,10 +36,10 @@ def run_upscale_correction( if upscale.scale > 1: if "esrgan" in upscale.upscale_model: - resr_stage = StageParams( + esrgan_stage = StageParams( tile_size=stage.tile_size, outscale=upscale.outscale ) - chain.append((upscale_resrgan, resr_stage, None)) + chain.append((upscale_resrgan, esrgan_stage, None)) elif "stable-diffusion" in upscale.upscale_model: mini_tile = min(SizeChart.mini, stage.tile_size) sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale) @@ -57,4 +58,12 @@ def run_upscale_correction( else: logger.warn("unknown correction model: %s", upscale.correction_model) - return chain(job, server, params, image, prompt=params.prompt, upscale=upscale) + return chain( + job, + server, + params, + image, + prompt=params.prompt, + upscale=upscale, + callback=callback, + )