1
0
Fork 0

use progress type in command

This commit is contained in:
Sean Sube 2024-01-04 19:09:52 -06:00
parent ce84dfa115
commit b6da935be6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 99 additions and 54 deletions

View File

@ -5,6 +5,8 @@ from typing import Any, List, Optional, Tuple
from PIL import Image
from ..worker.command import Progress
from ..errors import CancelledException, RetryException
from ..output import save_image
from ..params import ImageParams, Size, StageParams
@ -23,31 +25,43 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
parent: ProgressCallback
step: int
total: int
stage: int
tile: int
step: int # same as steps.current, left for legacy purposes
prev: int # accumulator when step resets
# new progress trackers
steps: Progress
stages: Progress
tiles: Progress
result: Optional[StageResult]
# TODO: total stages and tiles
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0
self.stage = 0
self.tile = 0
self.prev = 0
self.steps = Progress(self.step, self.prev)
self.stages = Progress(0, 0)
self.tiles = Progress(0, 0)
self.result = None
def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step
self.prev += self.step
self.step = step
self.parent(self.get_total(), timestep, latents)
total = self.get_total()
self.steps.current = total
self.parent(total, timestep, latents)
def get_total(self) -> int:
return self.step + self.total
return self.step + self.prev
def set_total(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
self.prev = steps
self.steps.total = steps
self.stages.total = stages
self.tiles.total = tiles
@classmethod
def from_progress(cls, parent: ProgressCallback):
@ -61,6 +75,8 @@ class ChainPipeline:
tiles as needed.
"""
stages: List[PipelineStage]
def __init__(
self,
stages: Optional[List[PipelineStage]] = None,
@ -124,6 +140,11 @@ class ChainPipeline:
if not isinstance(callback, ChainProgress):
callback = ChainProgress.from_progress(callback)
# set estimated totals
callback.set_total(
self.steps(params, sources.size), stages=len(self.stages), tiles=0
)
start = monotonic()
if len(sources) > 0:
@ -145,7 +166,7 @@ class ChainPipeline:
len(stage_sources),
kwargs.keys(),
)
callback.stage = stage_i
callback.stages.current = stage_i
per_stage_params = params
if "params" in kwargs:
@ -169,7 +190,7 @@ class ChainPipeline:
if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
callback.tile = 0 # reset this either way
callback.tiles.current = 0 # reset this either way
if must_tile:
logger.info(
"image contains sources or is larger than tile size of %s, tiling stage",
@ -188,7 +209,9 @@ class ChainPipeline:
server,
stage_params,
per_stage_params,
StageResult(images=source_tile, metadata=stage_sources.metadata),
StageResult(
images=source_tile, metadata=stage_sources.metadata
),
tile_mask=tile_mask,
callback=callback,
dims=dims,
@ -199,7 +222,7 @@ class ChainPipeline:
for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image)
callback.tile = callback.tile + 1
callback.tiles.current = callback.tiles.current + 1
return tile_result
except CancelledException as err:
@ -226,7 +249,9 @@ class ChainPipeline:
**kwargs,
)
stage_sources = StageResult(images=stage_results)
stage_sources = StageResult(
images=stage_results, metadata=stage_sources.metadata
)
else:
logger.debug(
"image does not contain sources and is within tile size of %s, running stage",

View File

@ -125,7 +125,7 @@ def run_txt2img_pipeline(
thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
images.insert_image(0, thumbnail)
images.insert_image(0, thumbnail, images.metadata[0])
save_result(server, images, worker.job)

View File

@ -13,24 +13,6 @@ Param = Union[str, int, float]
Point = Tuple[int, int]
class Progress:
current: int
total: int
def __init__(self, current: int, total: int) -> None:
self.current = current
self.total = total
def __str__(self) -> str:
return "%s/%s" % (self.current, self.total)
def tojson(self):
return {
"current": self.current,
"total": self.total,
}
class SizeChart(IntEnum):
micro = 64
mini = 128 # small tile for very expensive models

View File

@ -21,15 +21,33 @@ class JobType(str, Enum):
CHAIN = "chain"
class Progress:
current: int
total: int
def __init__(self, current: int, total: int) -> None:
self.current = current
self.total = total
def __str__(self) -> str:
return "%s/%s" % (self.current, self.total)
def tojson(self):
return {
"current": self.current,
"total": self.total,
}
class ProgressCommand:
device: str
job: str
job_type: str
status: JobStatus
result: Any
steps: int
stages: int
tiles: int
result: Any # really StageResult but that would be a very circular import
steps: Progress
stages: Progress
tiles: Progress
def __init__(
self,
@ -37,19 +55,21 @@ class ProgressCommand:
job_type: str,
device: str,
status: JobStatus,
steps: Progress,
stages: Progress,
tiles: Progress,
result: Any = None,
steps: int = 0,
stages: int = 0,
tiles: int = 0,
):
self.job = job
self.job_type = job_type
self.device = device
self.status = status
self.result = result
# progress info
self.steps = steps
self.stages = stages
self.tiles = tiles
self.result = result
class JobCommand:

View File

@ -7,7 +7,7 @@ from torch.multiprocessing import Queue, Value
from ..errors import CancelledException
from ..params import DeviceParams
from .command import JobCommand, JobStatus, ProgressCommand
from .command import JobCommand, JobStatus, Progress, ProgressCommand
logger = getLogger(__name__)
@ -90,11 +90,26 @@ class WorkerContext:
"""
return self.device
def get_progress(self) -> int:
def get_progress(self) -> Progress:
return self.get_last_steps()
def get_last_steps(self) -> Progress:
if self.last_progress is not None:
return self.last_progress.steps
return 0
return Progress(0, 0)
def get_last_stages(self) -> Progress:
if self.last_progress is not None:
return self.last_progress.stages
return Progress(0, 0)
def get_last_tiles(self) -> Progress:
if self.last_progress is not None:
return self.last_progress.tiles
return Progress(0, 0)
def get_progress_callback(self, reset=False) -> ProgressCallback:
from ..chain.pipeline import ChainProgress
@ -103,7 +118,7 @@ class WorkerContext:
return self.callback
def on_progress(step: int, timestep: int, latents: Any):
self.callback.step = step
# self.callback.step = step
self.set_progress(
step,
stages=self.callback.stage,
@ -138,9 +153,9 @@ class WorkerContext:
self.job_type,
self.device.device,
JobStatus.RUNNING,
steps=steps,
stages=stages,
tiles=tiles,
steps=Progress(steps, self.callback.steps.total),
stages=Progress(stages, self.callback.stages.total),
tiles=Progress(tiles, self.callback.tiles.total),
result=result,
)
self.progress.put(
@ -158,9 +173,9 @@ class WorkerContext:
self.job_type,
self.device.device,
JobStatus.SUCCESS,
steps=self.callback.step,
stages=self.callback.stage,
tiles=self.callback.tile,
steps=self.callback.steps,
stages=self.callback.stages,
tiles=self.callback.tiles,
result=self.callback.result,
)
self.progress.put(
@ -179,7 +194,10 @@ class WorkerContext:
self.job_type,
self.device.device,
JobStatus.FAILED,
steps=self.get_progress(),
steps=self.get_last_steps(),
stages=self.get_last_stages(),
tiles=self.get_last_tiles(),
# TODO: should this have results?
)
self.progress.put(
self.last_progress,