1
0
Fork 0

avoid waiting for final progress

This commit is contained in:
Sean Sube 2024-01-03 21:39:19 -06:00
parent 5a48447585
commit 9b5e894898
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 11 additions and 11 deletions

View File

@ -112,7 +112,7 @@ def run_txt2img_pipeline(
# run and save # run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch) latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback() progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, server, params, StageResult.empty(), callback=progress, latents=latents worker, server, params, StageResult.empty(), callback=progress, latents=latents
) )
@ -210,7 +210,7 @@ def run_img2img_pipeline(
) )
# run and append the filtered source # run and append the filtered source
progress = worker.get_progress_callback() progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, server, params, StageResult(images=[source]), callback=progress worker, server, params, StageResult(images=[source]), callback=progress
) )
@ -380,7 +380,7 @@ def run_inpaint_pipeline(
# run and save # run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch) latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback() progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, worker,
server, server,
@ -457,7 +457,7 @@ def run_upscale_pipeline(
) )
# run and save # run and save
progress = worker.get_progress_callback() progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, server, params, StageResult(images=[source]), callback=progress worker, server, params, StageResult(images=[source]), callback=progress
) )
@ -506,7 +506,7 @@ def run_blend_pipeline(
) )
# run and save # run and save
progress = worker.get_progress_callback() progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, server, params, StageResult(images=sources), callback=progress worker, server, params, StageResult(images=sources), callback=progress
) )

View File

@ -96,10 +96,10 @@ class WorkerContext:
return 0 return 0
def get_progress_callback(self) -> ProgressCallback: def get_progress_callback(self, reset=False) -> ProgressCallback:
from ..chain.pipeline import ChainProgress from ..chain.pipeline import ChainProgress
if self.callback is not None: if not reset and self.callback is not None:
return self.callback return self.callback
def on_progress(step: int, timestep: int, latents: Any): def on_progress(step: int, timestep: int, latents: Any):
@ -157,10 +157,10 @@ class WorkerContext:
self.job_type, self.job_type,
self.device.device, self.device.device,
JobStatus.SUCCESS, JobStatus.SUCCESS,
steps=self.last_progress.steps, steps=self.callback.steps,
stages=self.last_progress.stages, stages=self.callback.stages,
tiles=self.last_progress.tiles, tiles=self.callback.tiles,
results=self.last_progress.results, results=self.callback.results,
) )
self.progress.put( self.progress.put(
self.last_progress, self.last_progress,