From b6da935be6da2583e496acd4636b7dd8148deb33 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 4 Jan 2024 19:09:52 -0600 Subject: [PATCH] use progress type in command --- api/onnx_web/chain/pipeline.py | 57 ++++++++++++++++++++++++---------- api/onnx_web/diffusers/run.py | 2 +- api/onnx_web/params.py | 18 ----------- api/onnx_web/worker/command.py | 36 ++++++++++++++++----- api/onnx_web/worker/context.py | 40 +++++++++++++++++------- 5 files changed, 99 insertions(+), 54 deletions(-) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index 6c9451c1..a5f19830 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -5,6 +5,8 @@ from typing import Any, List, Optional, Tuple from PIL import Image +from ..worker.command import Progress + from ..errors import CancelledException, RetryException from ..output import save_image from ..params import ImageParams, Size, StageParams @@ -23,31 +25,43 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] class ChainProgress: parent: ProgressCallback - step: int - total: int - stage: int - tile: int + step: int # same as steps.current, left for legacy purposes + prev: int # accumulator when step resets + + # new progress trackers + steps: Progress + stages: Progress + tiles: Progress result: Optional[StageResult] - # TODO: total stages and tiles def __init__(self, parent: ProgressCallback, start=0) -> None: self.parent = parent self.step = start - self.total = 0 - self.stage = 0 - self.tile = 0 + self.prev = 0 + self.steps = Progress(self.step, self.prev) + self.stages = Progress(0, 0) + self.tiles = Progress(0, 0) self.result = None def __call__(self, step: int, timestep: int, latents: Any) -> None: if step < self.step: # accumulate on resets - self.total += self.step + self.prev += self.step self.step = step - self.parent(self.get_total(), timestep, latents) + + total = self.get_total() + self.steps.current = total + self.parent(total, timestep, latents) def get_total(self) -> int: - return self.step + self.total + return self.step + self.prev + + def set_total(self, steps: int, stages: int = 0, tiles: int = 0) -> None: + self.prev = steps + self.steps.total = steps + self.stages.total = stages + self.tiles.total = tiles @classmethod def from_progress(cls, parent: ProgressCallback): @@ -61,6 +75,8 @@ class ChainPipeline: tiles as needed. """ + stages: List[PipelineStage] + def __init__( self, stages: Optional[List[PipelineStage]] = None, @@ -124,6 +140,11 @@ class ChainPipeline: if not isinstance(callback, ChainProgress): callback = ChainProgress.from_progress(callback) + # set estimated totals + callback.set_total( + self.steps(params, sources.size), stages=len(self.stages), tiles=0 + ) + start = monotonic() if len(sources) > 0: @@ -145,7 +166,7 @@ class ChainPipeline: len(stage_sources), kwargs.keys(), ) - callback.stage = stage_i + callback.stages.current = stage_i per_stage_params = params if "params" in kwargs: @@ -169,7 +190,7 @@ class ChainPipeline: if stage_pipe.max_tile > 0: tile = min(stage_pipe.max_tile, stage_params.tile_size) - callback.tile = 0 # reset this either way + callback.tiles.current = 0 # reset this either way if must_tile: logger.info( "image contains sources or is larger than tile size of %s, tiling stage", @@ -188,7 +209,9 @@ class ChainPipeline: server, stage_params, per_stage_params, - StageResult(images=source_tile, metadata=stage_sources.metadata), + StageResult( + images=source_tile, metadata=stage_sources.metadata + ), tile_mask=tile_mask, callback=callback, dims=dims, @@ -199,7 +222,7 @@ class ChainPipeline: for j, image in enumerate(tile_result.as_image()): save_image(server, f"last-tile-{j}.png", image) - callback.tile = callback.tile + 1 + callback.tiles.current = callback.tiles.current + 1 return tile_result except CancelledException as err: @@ -226,7 +249,9 @@ class ChainPipeline: **kwargs, ) - stage_sources = StageResult(images=stage_results) + stage_sources = StageResult( + images=stage_results, metadata=stage_sources.metadata + ) else: logger.debug( "image does not contain sources and is within tile size of %s, running stage", diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 98472eb9..3ed614a3 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -125,7 +125,7 @@ def run_txt2img_pipeline( thumbnail = cover.copy() thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size)) - images.insert_image(0, thumbnail) + images.insert_image(0, thumbnail, images.metadata[0]) save_result(server, images, worker.job) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 894d04f2..b8dc3049 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -13,24 +13,6 @@ Param = Union[str, int, float] Point = Tuple[int, int] -class Progress: - current: int - total: int - - def __init__(self, current: int, total: int) -> None: - self.current = current - self.total = total - - def __str__(self) -> str: - return "%s/%s" % (self.current, self.total) - - def tojson(self): - return { - "current": self.current, - "total": self.total, - } - - class SizeChart(IntEnum): micro = 64 mini = 128 # small tile for very expensive models diff --git a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py index bc5587a5..31f38ee0 100644 --- a/api/onnx_web/worker/command.py +++ b/api/onnx_web/worker/command.py @@ -21,15 +21,33 @@ class JobType(str, Enum): CHAIN = "chain" +class Progress: + current: int + total: int + + def __init__(self, current: int, total: int) -> None: + self.current = current + self.total = total + + def __str__(self) -> str: + return "%s/%s" % (self.current, self.total) + + def tojson(self): + return { + "current": self.current, + "total": self.total, + } + + class ProgressCommand: device: str job: str job_type: str status: JobStatus - result: Any - steps: int - stages: int - tiles: int + result: Any # really StageResult but that would be a very circular import + steps: Progress + stages: Progress + tiles: Progress def __init__( self, @@ -37,19 +55,21 @@ class ProgressCommand: job_type: str, device: str, status: JobStatus, + steps: Progress, + stages: Progress, + tiles: Progress, result: Any = None, - steps: int = 0, - stages: int = 0, - tiles: int = 0, ): self.job = job self.job_type = job_type self.device = device self.status = status - self.result = result + + # progress info self.steps = steps self.stages = stages self.tiles = tiles + self.result = result class JobCommand: diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index a052778a..7e83a604 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -7,7 +7,7 @@ from torch.multiprocessing import Queue, Value from ..errors import CancelledException from ..params import DeviceParams -from .command import JobCommand, JobStatus, ProgressCommand +from .command import JobCommand, JobStatus, Progress, ProgressCommand logger = getLogger(__name__) @@ -90,11 +90,26 @@ class WorkerContext: """ return self.device - def get_progress(self) -> int: + def get_progress(self) -> Progress: + return self.get_last_steps() + + def get_last_steps(self) -> Progress: if self.last_progress is not None: return self.last_progress.steps - return 0 + return Progress(0, 0) + + def get_last_stages(self) -> Progress: + if self.last_progress is not None: + return self.last_progress.stages + + return Progress(0, 0) + + def get_last_tiles(self) -> Progress: + if self.last_progress is not None: + return self.last_progress.tiles + + return Progress(0, 0) def get_progress_callback(self, reset=False) -> ProgressCallback: from ..chain.pipeline import ChainProgress @@ -103,7 +118,7 @@ class WorkerContext: return self.callback def on_progress(step: int, timestep: int, latents: Any): - self.callback.step = step + # self.callback.step = step self.set_progress( step, stages=self.callback.stage, @@ -138,9 +153,9 @@ class WorkerContext: self.job_type, self.device.device, JobStatus.RUNNING, - steps=steps, - stages=stages, - tiles=tiles, + steps=Progress(steps, self.callback.steps.total), + stages=Progress(stages, self.callback.stages.total), + tiles=Progress(tiles, self.callback.tiles.total), result=result, ) self.progress.put( @@ -158,9 +173,9 @@ class WorkerContext: self.job_type, self.device.device, JobStatus.SUCCESS, - steps=self.callback.step, - stages=self.callback.stage, - tiles=self.callback.tile, + steps=self.callback.steps, + stages=self.callback.stages, + tiles=self.callback.tiles, result=self.callback.result, ) self.progress.put( @@ -179,7 +194,10 @@ class WorkerContext: self.job_type, self.device.device, JobStatus.FAILED, - steps=self.get_progress(), + steps=self.get_last_steps(), + stages=self.get_last_stages(), + tiles=self.get_last_tiles(), + # TODO: should this have results? ) self.progress.put( self.last_progress,