diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py
index 6aa36071..b48e1e54 100644
--- a/api/onnx_web/chain/base.py
+++ b/api/onnx_web/chain/base.py
@@ -5,7 +5,7 @@ from typing import Any, List, Optional, Protocol, Tuple
 
 from PIL import Image
 
-from ..device_pool import JobContext
+from ..device_pool import JobContext, ProgressCallback
 from ..output import save_image
 from ..params import ImageParams, StageParams
 from ..utils import ServerContext, is_debug
@@ -30,6 +30,24 @@ class StageCallback(Protocol):
 PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
 
 
+class ChainProgress:
+    def __init__(self, parent: ProgressCallback, start=0) -> None:
+        self.parent = parent
+        self.step = start
+        self.total = 0
+
+    def __call__(self, step: int, timestep: int, latents: Any) -> None:
+        if step < self.step:
+            # accumulate on resets
+            self.total += self.step
+
+        self.step = step
+        self.parent(self.get_total(), timestep, latents)
+
+    def get_total(self) -> int:
+        return self.step + self.total
+
+
 class ChainPipeline:
     """
     Run many stages in series, passing the image results from each to the next, and processing
@@ -57,11 +75,15 @@ class ChainPipeline:
         server: ServerContext,
         params: ImageParams,
         source: Image.Image,
+        callback: ProgressCallback = None,
         **pipeline_kwargs
     ) -> Image.Image:
         """
-        TODO: handle List[Image] outputs
+        TODO: handle List[Image] inputs and outputs
         """
+        if callback is not None:
+            callback = ChainProgress(callback, start=callback.step)
+
         start = monotonic()
         logger.info(
             "running pipeline on source image with dimensions %sx%s",
diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py
index d66e6f83..43568216 100644
--- a/api/onnx_web/chain/blend_img2img.py
+++ b/api/onnx_web/chain/blend_img2img.py
@@ -6,7 +6,7 @@ import torch
 from diffusers import OnnxStableDiffusionImg2ImgPipeline
 from PIL import Image
 
-from ..device_pool import JobContext
+from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import load_pipeline
 from ..params import ImageParams, StageParams
 from ..utils import ServerContext
@@ -23,6 +23,7 @@ def blend_img2img(
     *,
     strength: float,
     prompt: Optional[str] = None,
+    callback: ProgressCallback,
     **kwargs,
 ) -> Image.Image:
     prompt = prompt or params.prompt
@@ -46,6 +47,7 @@ def blend_img2img(
             negative_prompt=params.negative_prompt,
             num_inference_steps=params.steps,
             strength=strength,
+            callback=callback,
         )
     else:
         rng = np.random.RandomState(params.seed)
@@ -57,6 +59,7 @@ def blend_img2img(
             negative_prompt=params.negative_prompt,
             num_inference_steps=params.steps,
             strength=strength,
+            callback=callback,
         )
 
     output = result.images[0]
diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py
index 350bdfe0..9531f439 100644
--- a/api/onnx_web/chain/blend_inpaint.py
+++ b/api/onnx_web/chain/blend_inpaint.py
@@ -6,7 +6,7 @@ import torch
 from diffusers import OnnxStableDiffusionInpaintPipeline
 from PIL import Image
 
-from ..device_pool import JobContext
+from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
@@ -29,6 +29,7 @@ def blend_inpaint(
     fill_color: str = "white",
     mask_filter: Callable = mask_filter_none,
     noise_source: Callable = noise_source_histogram,
+    callback: ProgressCallback,
     **kwargs,
 ) -> Image.Image:
     logger.info("upscaling image by expanding borders", expand)
@@ -83,6 +84,7 @@ def blend_inpaint(
                 negative_prompt=params.negative_prompt,
                 num_inference_steps=params.steps,
                 width=size.width,
+                callback=callback,
             )
         else:
             rng = np.random.RandomState(params.seed)
@@ -97,6 +99,7 @@ def blend_inpaint(
                 negative_prompt=params.negative_prompt,
                 num_inference_steps=params.steps,
                 width=size.width,
+                callback=callback,
             )
 
         return result.images[0]
diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py
index ee584024..af32da0a 100644
--- a/api/onnx_web/chain/source_txt2img.py
+++ b/api/onnx_web/chain/source_txt2img.py
@@ -5,7 +5,7 @@ import torch
 from diffusers import OnnxStableDiffusionPipeline
 from PIL import Image
 
-from ..device_pool import JobContext
+from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..params import ImageParams, Size, StageParams
 from ..utils import ServerContext
@@ -22,6 +22,7 @@ def source_txt2img(
     *,
     size: Size,
     prompt: str = None,
+    callback: ProgressCallback = None,
     **kwargs,
 ) -> Image.Image:
     prompt = prompt or params.prompt
@@ -53,6 +54,7 @@ def source_txt2img(
             latents=latents,
             negative_prompt=params.negative_prompt,
             num_inference_steps=params.steps,
+            callback=callback,
         )
     else:
         rng = np.random.RandomState(params.seed)
@@ -65,6 +67,7 @@ def source_txt2img(
             latents=latents,
             negative_prompt=params.negative_prompt,
             num_inference_steps=params.steps,
+            callback=callback,
         )
 
     output = result.images[0]
diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py
index 621c4eb5..25e87173 100644
--- a/api/onnx_web/chain/upscale_outpaint.py
+++ b/api/onnx_web/chain/upscale_outpaint.py
@@ -6,7 +6,7 @@ import torch
 from diffusers import OnnxStableDiffusionInpaintPipeline
 from PIL import Image, ImageDraw
 
-from ..device_pool import JobContext
+from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
@@ -30,6 +30,7 @@ def upscale_outpaint(
     fill_color: str = "white",
     mask_filter: Callable = mask_filter_none,
     noise_source: Callable = noise_source_histogram,
+    callback: ProgressCallback,
     **kwargs,
 ) -> Image.Image:
     prompt = prompt or params.prompt
@@ -92,6 +93,7 @@ def upscale_outpaint(
                 negative_prompt=params.negative_prompt,
                 num_inference_steps=params.steps,
                 width=size.width,
+                callback=callback,
             )
         else:
             rng = np.random.RandomState(params.seed)
@@ -106,6 +108,7 @@ def upscale_outpaint(
                 negative_prompt=params.negative_prompt,
                 num_inference_steps=params.steps,
                 width=size.width,
+                callback=callback,
             )
 
         # once part of the image has been drawn, keep it
diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py
index 43614d1d..69654e94 100644
--- a/api/onnx_web/chain/upscale_stable_diffusion.py
+++ b/api/onnx_web/chain/upscale_stable_diffusion.py
@@ -5,7 +5,7 @@ import torch
 from diffusers import StableDiffusionUpscalePipeline
 from PIL import Image
 
-from ..device_pool import JobContext
+from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
@@ -67,6 +67,7 @@ def upscale_stable_diffusion(
     *,
     upscale: UpscaleParams,
     prompt: str = None,
+    callback: ProgressCallback,
     **kwargs,
 ) -> Image.Image:
     prompt = prompt or params.prompt
@@ -80,4 +81,5 @@ def upscale_stable_diffusion(
         source,
         generator=generator,
         num_inference_steps=params.steps,
+        callback=callback,
     ).images[0]
diff --git a/api/onnx_web/device_pool.py b/api/onnx_web/device_pool.py
index d52c5671..0d359ad7 100644
--- a/api/onnx_web/device_pool.py
+++ b/api/onnx_web/device_pool.py
@@ -10,6 +10,8 @@ from .utils import run_gc
 
 logger = getLogger(__name__)
 
+ProgressCallback = Callable[[int, int, Any], None]
+
 
 class JobContext:
     cancel: Value = None
@@ -51,8 +53,9 @@ class JobContext:
     def get_progress(self) -> int:
         return self.progress.value
 
-    def get_progress_callback(self) -> Callable[..., None]:
+    def get_progress_callback(self) -> ProgressCallback:
         def on_progress(step: int, timestep: int, latents: Any):
+            on_progress.step = step
            if self.is_cancelled():
                raise Exception("job has been cancelled")
            else:
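
For reference, and not part of the patch above: a minimal sketch of how the new ChainProgress wrapper behaves when chained stages each restart their diffusion step counter at zero. The parent callback here is a hypothetical stand-in; its (step, timestep, latents) signature follows the ProgressCallback type added in device_pool.py.

# Illustrative usage only; assumes ChainProgress from api/onnx_web/chain/base.py.
from onnx_web.chain.base import ChainProgress

def parent(step: int, timestep: int, latents) -> None:
    # receives the cumulative step count across all stages
    print("cumulative steps:", step)

progress = ChainProgress(parent, start=0)

# stage 1 reports steps 0..24
for step in range(25):
    progress(step, 0, None)

# stage 2 resets to step 0, which triggers the accumulation branch
for step in range(25):
    progress(step, 0, None)

assert progress.get_total() == 48  # 24 from stage 1 + 24 from stage 2

Because each per-stage callback resets to zero, ChainProgress folds the previous stage's high-water mark into self.total, so the parent callback (and the job progress it feeds) only ever sees a monotonically increasing count.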