From 034be3266eb4cefcdb513a92677884d5286a63d9 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 12 Feb 2023 13:16:17 -0600 Subject: [PATCH] fix(api): accumulate progress from inpaint pipelines (#90) --- api/logging.yaml | 8 ++++---- api/onnx_web/chain/base.py | 8 ++++++-- api/onnx_web/diffusion/run.py | 21 +++++++++++++++++++-- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/api/logging.yaml b/api/logging.yaml index 8d345cec..24bd3c29 100644 --- a/api/logging.yaml +++ b/api/logging.yaml @@ -5,14 +5,14 @@ formatters: handlers: console: class: logging.StreamHandler - level: INFO + level: DEBUG formatter: simple stream: ext://sys.stdout loggers: '': - level: INFO + level: DEBUG handlers: [console] propagate: True root: - level: INFO - handlers: [console] \ No newline at end of file + level: DEBUG + handlers: [console] diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 1ad6a1de..dd661f30 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -47,6 +47,11 @@ class ChainProgress: def get_total(self) -> int: return self.step + self.total + @classmethod + def from_progress(cls, parent: ProgressCallback): + start = parent.step if hasattr(parent, "step") else 0 + return ChainProgress(parent, start=start) + class ChainPipeline: """ @@ -82,8 +87,7 @@ class ChainPipeline: TODO: handle List[Image] inputs and outputs """ if callback is not None: - start = callback.step if hasattr(callback, "step") else 0 - callback = ChainProgress(callback, start=start) + callback = ChainProgress.from_progress(callback) start = monotonic() logger.info( diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 225bd3e5..2b776b9c 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -6,6 +6,8 @@ import torch from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from PIL import Image, ImageChops +from onnx_web.chain.base import ChainProgress + from ..chain import upscale_outpaint from ..device_pool import JobContext from ..output import save_image, save_params @@ -65,7 +67,13 @@ def run_txt2img_pipeline( image = result.images[0] image = run_upscale_correction( - job, server, StageParams(), params, image, upscale=upscale, callback=progress, + job, + server, + StageParams(), + params, + image, + upscale=upscale, + callback=progress, ) dest = save_image(server, output, image) @@ -123,7 +131,13 @@ def run_img2img_pipeline( image = result.images[0] image = run_upscale_correction( - job, server, StageParams(), params, image, upscale=upscale, callback=progress, + job, + server, + StageParams(), + params, + image, + upscale=upscale, + callback=progress, ) dest = save_image(server, output, image) @@ -157,6 +171,9 @@ def run_inpaint_pipeline( progress = job.get_progress_callback() stage = StageParams(tile_order=tile_order) + # calling the upscale_outpaint stage directly needs accumulating progress + progress = ChainProgress.from_progress(progress) + image = upscale_outpaint( job, server,