report stage and tile count in progress
This commit is contained in:
parent
9d05d9baac
commit
0c504e3f69
|
@ -22,6 +22,13 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
||||||
|
|
||||||
|
|
||||||
class ChainProgress:
|
class ChainProgress:
|
||||||
|
parent: ProgressCallback
|
||||||
|
step: int
|
||||||
|
total: int
|
||||||
|
stage: int
|
||||||
|
tile: int
|
||||||
|
# TODO: total stages and tiles
|
||||||
|
|
||||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.step = start
|
self.step = start
|
||||||
|
@ -94,12 +101,8 @@ class ChainPipeline:
|
||||||
|
|
||||||
return steps
|
return steps
|
||||||
|
|
||||||
def outputs(self, params: ImageParams, sources: int) -> int:
|
def stages(self) -> int:
|
||||||
outputs = sources
|
return len(self.stages)
|
||||||
for callback, _params, kwargs in self.stages:
|
|
||||||
outputs = callback.outputs(kwargs.get("params", params), outputs)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -110,12 +113,11 @@ class ChainPipeline:
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**pipeline_kwargs,
|
**pipeline_kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
"""
|
|
||||||
DEPRECATED: use `.run()` instead
|
|
||||||
"""
|
|
||||||
if callback is None:
|
if callback is None:
|
||||||
callback = worker.get_progress_callback()
|
callback = worker.get_progress_callback()
|
||||||
else:
|
|
||||||
|
# wrap the progress counter in a one that can be reset if needed
|
||||||
|
if not isinstance(callback, ChainProgress):
|
||||||
callback = ChainProgress.from_progress(callback)
|
callback = ChainProgress.from_progress(callback)
|
||||||
|
|
||||||
start = monotonic()
|
start = monotonic()
|
||||||
|
@ -129,7 +131,7 @@ class ChainPipeline:
|
||||||
logger.info("running pipeline without source images")
|
logger.info("running pipeline without source images")
|
||||||
|
|
||||||
stage_sources = sources
|
stage_sources = sources
|
||||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
for stage_i, (stage_pipe, stage_params, stage_kwargs) in enumerate(self.stages):
|
||||||
name = stage_params.name or stage_pipe.__class__.__name__
|
name = stage_params.name or stage_pipe.__class__.__name__
|
||||||
kwargs = stage_kwargs or {}
|
kwargs = stage_kwargs or {}
|
||||||
kwargs = {**pipeline_kwargs, **kwargs}
|
kwargs = {**pipeline_kwargs, **kwargs}
|
||||||
|
@ -139,6 +141,7 @@ class ChainPipeline:
|
||||||
len(stage_sources),
|
len(stage_sources),
|
||||||
kwargs.keys(),
|
kwargs.keys(),
|
||||||
)
|
)
|
||||||
|
callback.stage = stage_i
|
||||||
|
|
||||||
per_stage_params = params
|
per_stage_params = params
|
||||||
if "params" in kwargs:
|
if "params" in kwargs:
|
||||||
|
@ -168,6 +171,7 @@ class ChainPipeline:
|
||||||
tile,
|
tile,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
callback.tile = 0
|
||||||
def stage_tile(
|
def stage_tile(
|
||||||
source_tile: List[Image.Image],
|
source_tile: List[Image.Image],
|
||||||
tile_mask: Image.Image,
|
tile_mask: Image.Image,
|
||||||
|
@ -191,6 +195,8 @@ class ChainPipeline:
|
||||||
for j, image in enumerate(tile_result.as_image()):
|
for j, image in enumerate(tile_result.as_image()):
|
||||||
save_image(server, f"last-tile-{j}.png", image)
|
save_image(server, f"last-tile-{j}.png", image)
|
||||||
|
|
||||||
|
callback.tile = callback.tile + 1
|
||||||
|
|
||||||
return tile_result
|
return tile_result
|
||||||
except CancelledException as err:
|
except CancelledException as err:
|
||||||
worker.retries = 0
|
worker.retries = 0
|
||||||
|
|
|
@ -28,6 +28,7 @@ class WorkerContext:
|
||||||
timeout: float
|
timeout: float
|
||||||
retries: int
|
retries: int
|
||||||
initial_retries: int
|
initial_retries: int
|
||||||
|
callback: Optional[Any]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -97,11 +98,15 @@ class WorkerContext:
|
||||||
def get_progress_callback(self) -> ProgressCallback:
|
def get_progress_callback(self) -> ProgressCallback:
|
||||||
from ..chain.pipeline import ChainProgress
|
from ..chain.pipeline import ChainProgress
|
||||||
|
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
if self.callback is not None:
|
||||||
on_progress.step = step
|
return self.callback
|
||||||
self.set_progress(step)
|
|
||||||
|
|
||||||
return ChainProgress.from_progress(on_progress)
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
|
self.callback.step = step
|
||||||
|
self.set_progress(step, stages=self.callback.stage, tiles=self.callback.tile)
|
||||||
|
|
||||||
|
self.callback = ChainProgress.from_progress(on_progress)
|
||||||
|
return self.callback
|
||||||
|
|
||||||
def set_cancel(self, cancel: bool = True) -> None:
|
def set_cancel(self, cancel: bool = True) -> None:
|
||||||
with self.cancel.get_lock():
|
with self.cancel.get_lock():
|
||||||
|
@ -111,20 +116,22 @@ class WorkerContext:
|
||||||
with self.idle.get_lock():
|
with self.idle.get_lock():
|
||||||
self.idle.value = idle
|
self.idle.value = idle
|
||||||
|
|
||||||
def set_progress(self, progress: int) -> None:
|
def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
|
||||||
if self.job is None:
|
if self.job is None:
|
||||||
raise RuntimeError("no job on which to set progress")
|
raise RuntimeError("no job on which to set progress")
|
||||||
|
|
||||||
if self.is_cancelled():
|
if self.is_cancelled():
|
||||||
raise CancelledException("job has been cancelled")
|
raise CancelledException("job has been cancelled")
|
||||||
|
|
||||||
logger.debug("setting progress for job %s to %s", self.job, progress)
|
logger.debug("setting progress for job %s to %s", self.job, steps)
|
||||||
self.last_progress = ProgressCommand(
|
self.last_progress = ProgressCommand(
|
||||||
self.job,
|
self.job,
|
||||||
self.job_type,
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
JobStatus.RUNNING,
|
JobStatus.RUNNING,
|
||||||
steps=progress,
|
steps=steps,
|
||||||
|
stages=stages,
|
||||||
|
tiles=tiles,
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
self.last_progress,
|
self.last_progress,
|
||||||
|
|
Loading…
Reference in New Issue