use progress type in command
This commit is contained in:
parent
ce84dfa115
commit
b6da935be6
|
@ -5,6 +5,8 @@ from typing import Any, List, Optional, Tuple
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..worker.command import Progress
|
||||
|
||||
from ..errors import CancelledException, RetryException
|
||||
from ..output import save_image
|
||||
from ..params import ImageParams, Size, StageParams
|
||||
|
@ -23,31 +25,43 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
|||
|
||||
class ChainProgress:
|
||||
parent: ProgressCallback
|
||||
step: int
|
||||
total: int
|
||||
stage: int
|
||||
tile: int
|
||||
step: int # same as steps.current, left for legacy purposes
|
||||
prev: int # accumulator when step resets
|
||||
|
||||
# new progress trackers
|
||||
steps: Progress
|
||||
stages: Progress
|
||||
tiles: Progress
|
||||
result: Optional[StageResult]
|
||||
# TODO: total stages and tiles
|
||||
|
||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||
self.parent = parent
|
||||
self.step = start
|
||||
self.total = 0
|
||||
self.stage = 0
|
||||
self.tile = 0
|
||||
self.prev = 0
|
||||
self.steps = Progress(self.step, self.prev)
|
||||
self.stages = Progress(0, 0)
|
||||
self.tiles = Progress(0, 0)
|
||||
self.result = None
|
||||
|
||||
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||
if step < self.step:
|
||||
# accumulate on resets
|
||||
self.total += self.step
|
||||
self.prev += self.step
|
||||
|
||||
self.step = step
|
||||
self.parent(self.get_total(), timestep, latents)
|
||||
|
||||
total = self.get_total()
|
||||
self.steps.current = total
|
||||
self.parent(total, timestep, latents)
|
||||
|
||||
def get_total(self) -> int:
|
||||
return self.step + self.total
|
||||
return self.step + self.prev
|
||||
|
||||
def set_total(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
|
||||
self.prev = steps
|
||||
self.steps.total = steps
|
||||
self.stages.total = stages
|
||||
self.tiles.total = tiles
|
||||
|
||||
@classmethod
|
||||
def from_progress(cls, parent: ProgressCallback):
|
||||
|
@ -61,6 +75,8 @@ class ChainPipeline:
|
|||
tiles as needed.
|
||||
"""
|
||||
|
||||
stages: List[PipelineStage]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stages: Optional[List[PipelineStage]] = None,
|
||||
|
@ -124,6 +140,11 @@ class ChainPipeline:
|
|||
if not isinstance(callback, ChainProgress):
|
||||
callback = ChainProgress.from_progress(callback)
|
||||
|
||||
# set estimated totals
|
||||
callback.set_total(
|
||||
self.steps(params, sources.size), stages=len(self.stages), tiles=0
|
||||
)
|
||||
|
||||
start = monotonic()
|
||||
|
||||
if len(sources) > 0:
|
||||
|
@ -145,7 +166,7 @@ class ChainPipeline:
|
|||
len(stage_sources),
|
||||
kwargs.keys(),
|
||||
)
|
||||
callback.stage = stage_i
|
||||
callback.stages.current = stage_i
|
||||
|
||||
per_stage_params = params
|
||||
if "params" in kwargs:
|
||||
|
@ -169,7 +190,7 @@ class ChainPipeline:
|
|||
if stage_pipe.max_tile > 0:
|
||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||
|
||||
callback.tile = 0 # reset this either way
|
||||
callback.tiles.current = 0 # reset this either way
|
||||
if must_tile:
|
||||
logger.info(
|
||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||
|
@ -188,7 +209,9 @@ class ChainPipeline:
|
|||
server,
|
||||
stage_params,
|
||||
per_stage_params,
|
||||
StageResult(images=source_tile, metadata=stage_sources.metadata),
|
||||
StageResult(
|
||||
images=source_tile, metadata=stage_sources.metadata
|
||||
),
|
||||
tile_mask=tile_mask,
|
||||
callback=callback,
|
||||
dims=dims,
|
||||
|
@ -199,7 +222,7 @@ class ChainPipeline:
|
|||
for j, image in enumerate(tile_result.as_image()):
|
||||
save_image(server, f"last-tile-{j}.png", image)
|
||||
|
||||
callback.tile = callback.tile + 1
|
||||
callback.tiles.current = callback.tiles.current + 1
|
||||
|
||||
return tile_result
|
||||
except CancelledException as err:
|
||||
|
@ -226,7 +249,9 @@ class ChainPipeline:
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
stage_sources = StageResult(images=stage_results)
|
||||
stage_sources = StageResult(
|
||||
images=stage_results, metadata=stage_sources.metadata
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"image does not contain sources and is within tile size of %s, running stage",
|
||||
|
|
|
@ -125,7 +125,7 @@ def run_txt2img_pipeline(
|
|||
thumbnail = cover.copy()
|
||||
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
||||
|
||||
images.insert_image(0, thumbnail)
|
||||
images.insert_image(0, thumbnail, images.metadata[0])
|
||||
|
||||
save_result(server, images, worker.job)
|
||||
|
||||
|
|
|
@ -13,24 +13,6 @@ Param = Union[str, int, float]
|
|||
Point = Tuple[int, int]
|
||||
|
||||
|
||||
class Progress:
|
||||
current: int
|
||||
total: int
|
||||
|
||||
def __init__(self, current: int, total: int) -> None:
|
||||
self.current = current
|
||||
self.total = total
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "%s/%s" % (self.current, self.total)
|
||||
|
||||
def tojson(self):
|
||||
return {
|
||||
"current": self.current,
|
||||
"total": self.total,
|
||||
}
|
||||
|
||||
|
||||
class SizeChart(IntEnum):
|
||||
micro = 64
|
||||
mini = 128 # small tile for very expensive models
|
||||
|
|
|
@ -21,15 +21,33 @@ class JobType(str, Enum):
|
|||
CHAIN = "chain"
|
||||
|
||||
|
||||
class Progress:
|
||||
current: int
|
||||
total: int
|
||||
|
||||
def __init__(self, current: int, total: int) -> None:
|
||||
self.current = current
|
||||
self.total = total
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "%s/%s" % (self.current, self.total)
|
||||
|
||||
def tojson(self):
|
||||
return {
|
||||
"current": self.current,
|
||||
"total": self.total,
|
||||
}
|
||||
|
||||
|
||||
class ProgressCommand:
|
||||
device: str
|
||||
job: str
|
||||
job_type: str
|
||||
status: JobStatus
|
||||
result: Any
|
||||
steps: int
|
||||
stages: int
|
||||
tiles: int
|
||||
result: Any # really StageResult but that would be a very circular import
|
||||
steps: Progress
|
||||
stages: Progress
|
||||
tiles: Progress
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -37,19 +55,21 @@ class ProgressCommand:
|
|||
job_type: str,
|
||||
device: str,
|
||||
status: JobStatus,
|
||||
steps: Progress,
|
||||
stages: Progress,
|
||||
tiles: Progress,
|
||||
result: Any = None,
|
||||
steps: int = 0,
|
||||
stages: int = 0,
|
||||
tiles: int = 0,
|
||||
):
|
||||
self.job = job
|
||||
self.job_type = job_type
|
||||
self.device = device
|
||||
self.status = status
|
||||
self.result = result
|
||||
|
||||
# progress info
|
||||
self.steps = steps
|
||||
self.stages = stages
|
||||
self.tiles = tiles
|
||||
self.result = result
|
||||
|
||||
|
||||
class JobCommand:
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch.multiprocessing import Queue, Value
|
|||
|
||||
from ..errors import CancelledException
|
||||
from ..params import DeviceParams
|
||||
from .command import JobCommand, JobStatus, ProgressCommand
|
||||
from .command import JobCommand, JobStatus, Progress, ProgressCommand
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -90,11 +90,26 @@ class WorkerContext:
|
|||
"""
|
||||
return self.device
|
||||
|
||||
def get_progress(self) -> int:
|
||||
def get_progress(self) -> Progress:
|
||||
return self.get_last_steps()
|
||||
|
||||
def get_last_steps(self) -> Progress:
|
||||
if self.last_progress is not None:
|
||||
return self.last_progress.steps
|
||||
|
||||
return 0
|
||||
return Progress(0, 0)
|
||||
|
||||
def get_last_stages(self) -> Progress:
|
||||
if self.last_progress is not None:
|
||||
return self.last_progress.stages
|
||||
|
||||
return Progress(0, 0)
|
||||
|
||||
def get_last_tiles(self) -> Progress:
|
||||
if self.last_progress is not None:
|
||||
return self.last_progress.tiles
|
||||
|
||||
return Progress(0, 0)
|
||||
|
||||
def get_progress_callback(self, reset=False) -> ProgressCallback:
|
||||
from ..chain.pipeline import ChainProgress
|
||||
|
@ -103,7 +118,7 @@ class WorkerContext:
|
|||
return self.callback
|
||||
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
self.callback.step = step
|
||||
# self.callback.step = step
|
||||
self.set_progress(
|
||||
step,
|
||||
stages=self.callback.stage,
|
||||
|
@ -138,9 +153,9 @@ class WorkerContext:
|
|||
self.job_type,
|
||||
self.device.device,
|
||||
JobStatus.RUNNING,
|
||||
steps=steps,
|
||||
stages=stages,
|
||||
tiles=tiles,
|
||||
steps=Progress(steps, self.callback.steps.total),
|
||||
stages=Progress(stages, self.callback.stages.total),
|
||||
tiles=Progress(tiles, self.callback.tiles.total),
|
||||
result=result,
|
||||
)
|
||||
self.progress.put(
|
||||
|
@ -158,9 +173,9 @@ class WorkerContext:
|
|||
self.job_type,
|
||||
self.device.device,
|
||||
JobStatus.SUCCESS,
|
||||
steps=self.callback.step,
|
||||
stages=self.callback.stage,
|
||||
tiles=self.callback.tile,
|
||||
steps=self.callback.steps,
|
||||
stages=self.callback.stages,
|
||||
tiles=self.callback.tiles,
|
||||
result=self.callback.result,
|
||||
)
|
||||
self.progress.put(
|
||||
|
@ -179,7 +194,10 @@ class WorkerContext:
|
|||
self.job_type,
|
||||
self.device.device,
|
||||
JobStatus.FAILED,
|
||||
steps=self.get_progress(),
|
||||
steps=self.get_last_steps(),
|
||||
stages=self.get_last_stages(),
|
||||
tiles=self.get_last_tiles(),
|
||||
# TODO: should this have results?
|
||||
)
|
||||
self.progress.put(
|
||||
self.last_progress,
|
||||
|
|
Loading…
Reference in New Issue