From 9b5e89489832ff43671d6dbf45641747c70cfa3b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 3 Jan 2024 21:39:19 -0600 Subject: [PATCH] avoid waiting for final progress --- api/onnx_web/diffusers/run.py | 10 +++++----- api/onnx_web/worker/context.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 5bc625c4..98472eb9 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -112,7 +112,7 @@ def run_txt2img_pipeline( # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) - progress = worker.get_progress_callback() + progress = worker.get_progress_callback(reset=True) images = chain( worker, server, params, StageResult.empty(), callback=progress, latents=latents ) @@ -210,7 +210,7 @@ def run_img2img_pipeline( ) # run and append the filtered source - progress = worker.get_progress_callback() + progress = worker.get_progress_callback(reset=True) images = chain( worker, server, params, StageResult(images=[source]), callback=progress ) @@ -380,7 +380,7 @@ def run_inpaint_pipeline( # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) - progress = worker.get_progress_callback() + progress = worker.get_progress_callback(reset=True) images = chain( worker, server, @@ -457,7 +457,7 @@ def run_upscale_pipeline( ) # run and save - progress = worker.get_progress_callback() + progress = worker.get_progress_callback(reset=True) images = chain( worker, server, params, StageResult(images=[source]), callback=progress ) @@ -506,7 +506,7 @@ def run_blend_pipeline( ) # run and save - progress = worker.get_progress_callback() + progress = worker.get_progress_callback(reset=True) images = chain( worker, server, params, StageResult(images=sources), callback=progress ) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 2e832169..3ec4be47 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -96,10 +96,10 @@ class WorkerContext: return 0 - def get_progress_callback(self) -> ProgressCallback: + def get_progress_callback(self, reset=False) -> ProgressCallback: from ..chain.pipeline import ChainProgress - if self.callback is not None: + if not reset and self.callback is not None: return self.callback def on_progress(step: int, timestep: int, latents: Any): @@ -157,10 +157,10 @@ class WorkerContext: self.job_type, self.device.device, JobStatus.SUCCESS, - steps=self.last_progress.steps, - stages=self.last_progress.stages, - tiles=self.last_progress.tiles, - results=self.last_progress.results, + steps=self.callback.steps, + stages=self.callback.stages, + tiles=self.callback.tiles, + results=self.callback.results, ) self.progress.put( self.last_progress,