1
0
Fork 0

use progress type in command

This commit is contained in:
Sean Sube 2024-01-04 19:09:52 -06:00
parent ce84dfa115
commit b6da935be6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 99 additions and 54 deletions

View File

@ -5,6 +5,8 @@ from typing import Any, List, Optional, Tuple
from PIL import Image from PIL import Image
from ..worker.command import Progress
from ..errors import CancelledException, RetryException from ..errors import CancelledException, RetryException
from ..output import save_image from ..output import save_image
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
@ -23,31 +25,43 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress: class ChainProgress:
parent: ProgressCallback parent: ProgressCallback
step: int step: int # same as steps.current, left for legacy purposes
total: int prev: int # accumulator when step resets
stage: int
tile: int # new progress trackers
steps: Progress
stages: Progress
tiles: Progress
result: Optional[StageResult] result: Optional[StageResult]
# TODO: total stages and tiles
def __init__(self, parent: ProgressCallback, start=0) -> None: def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent self.parent = parent
self.step = start self.step = start
self.total = 0 self.prev = 0
self.stage = 0 self.steps = Progress(self.step, self.prev)
self.tile = 0 self.stages = Progress(0, 0)
self.tiles = Progress(0, 0)
self.result = None self.result = None
def __call__(self, step: int, timestep: int, latents: Any) -> None: def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step: if step < self.step:
# accumulate on resets # accumulate on resets
self.total += self.step self.prev += self.step
self.step = step self.step = step
self.parent(self.get_total(), timestep, latents)
total = self.get_total()
self.steps.current = total
self.parent(total, timestep, latents)
def get_total(self) -> int: def get_total(self) -> int:
return self.step + self.total return self.step + self.prev
def set_total(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
self.prev = steps
self.steps.total = steps
self.stages.total = stages
self.tiles.total = tiles
@classmethod @classmethod
def from_progress(cls, parent: ProgressCallback): def from_progress(cls, parent: ProgressCallback):
@ -61,6 +75,8 @@ class ChainPipeline:
tiles as needed. tiles as needed.
""" """
stages: List[PipelineStage]
def __init__( def __init__(
self, self,
stages: Optional[List[PipelineStage]] = None, stages: Optional[List[PipelineStage]] = None,
@ -124,6 +140,11 @@ class ChainPipeline:
if not isinstance(callback, ChainProgress): if not isinstance(callback, ChainProgress):
callback = ChainProgress.from_progress(callback) callback = ChainProgress.from_progress(callback)
# set estimated totals
callback.set_total(
self.steps(params, sources.size), stages=len(self.stages), tiles=0
)
start = monotonic() start = monotonic()
if len(sources) > 0: if len(sources) > 0:
@ -145,7 +166,7 @@ class ChainPipeline:
len(stage_sources), len(stage_sources),
kwargs.keys(), kwargs.keys(),
) )
callback.stage = stage_i callback.stages.current = stage_i
per_stage_params = params per_stage_params = params
if "params" in kwargs: if "params" in kwargs:
@ -169,7 +190,7 @@ class ChainPipeline:
if stage_pipe.max_tile > 0: if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size) tile = min(stage_pipe.max_tile, stage_params.tile_size)
callback.tile = 0 # reset this either way callback.tiles.current = 0 # reset this either way
if must_tile: if must_tile:
logger.info( logger.info(
"image contains sources or is larger than tile size of %s, tiling stage", "image contains sources or is larger than tile size of %s, tiling stage",
@ -188,7 +209,9 @@ class ChainPipeline:
server, server,
stage_params, stage_params,
per_stage_params, per_stage_params,
StageResult(images=source_tile, metadata=stage_sources.metadata), StageResult(
images=source_tile, metadata=stage_sources.metadata
),
tile_mask=tile_mask, tile_mask=tile_mask,
callback=callback, callback=callback,
dims=dims, dims=dims,
@ -199,7 +222,7 @@ class ChainPipeline:
for j, image in enumerate(tile_result.as_image()): for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image) save_image(server, f"last-tile-{j}.png", image)
callback.tile = callback.tile + 1 callback.tiles.current = callback.tiles.current + 1
return tile_result return tile_result
except CancelledException as err: except CancelledException as err:
@ -226,7 +249,9 @@ class ChainPipeline:
**kwargs, **kwargs,
) )
stage_sources = StageResult(images=stage_results) stage_sources = StageResult(
images=stage_results, metadata=stage_sources.metadata
)
else: else:
logger.debug( logger.debug(
"image does not contain sources and is within tile size of %s, running stage", "image does not contain sources and is within tile size of %s, running stage",

View File

@ -125,7 +125,7 @@ def run_txt2img_pipeline(
thumbnail = cover.copy() thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size)) thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
images.insert_image(0, thumbnail) images.insert_image(0, thumbnail, images.metadata[0])
save_result(server, images, worker.job) save_result(server, images, worker.job)

View File

@ -13,24 +13,6 @@ Param = Union[str, int, float]
Point = Tuple[int, int] Point = Tuple[int, int]
class Progress:
current: int
total: int
def __init__(self, current: int, total: int) -> None:
self.current = current
self.total = total
def __str__(self) -> str:
return "%s/%s" % (self.current, self.total)
def tojson(self):
return {
"current": self.current,
"total": self.total,
}
class SizeChart(IntEnum): class SizeChart(IntEnum):
micro = 64 micro = 64
mini = 128 # small tile for very expensive models mini = 128 # small tile for very expensive models

View File

@ -21,15 +21,33 @@ class JobType(str, Enum):
CHAIN = "chain" CHAIN = "chain"
class Progress:
current: int
total: int
def __init__(self, current: int, total: int) -> None:
self.current = current
self.total = total
def __str__(self) -> str:
return "%s/%s" % (self.current, self.total)
def tojson(self):
return {
"current": self.current,
"total": self.total,
}
class ProgressCommand: class ProgressCommand:
device: str device: str
job: str job: str
job_type: str job_type: str
status: JobStatus status: JobStatus
result: Any result: Any # really StageResult but that would be a very circular import
steps: int steps: Progress
stages: int stages: Progress
tiles: int tiles: Progress
def __init__( def __init__(
self, self,
@ -37,19 +55,21 @@ class ProgressCommand:
job_type: str, job_type: str,
device: str, device: str,
status: JobStatus, status: JobStatus,
steps: Progress,
stages: Progress,
tiles: Progress,
result: Any = None, result: Any = None,
steps: int = 0,
stages: int = 0,
tiles: int = 0,
): ):
self.job = job self.job = job
self.job_type = job_type self.job_type = job_type
self.device = device self.device = device
self.status = status self.status = status
self.result = result
# progress info
self.steps = steps self.steps = steps
self.stages = stages self.stages = stages
self.tiles = tiles self.tiles = tiles
self.result = result
class JobCommand: class JobCommand:

View File

@ -7,7 +7,7 @@ from torch.multiprocessing import Queue, Value
from ..errors import CancelledException from ..errors import CancelledException
from ..params import DeviceParams from ..params import DeviceParams
from .command import JobCommand, JobStatus, ProgressCommand from .command import JobCommand, JobStatus, Progress, ProgressCommand
logger = getLogger(__name__) logger = getLogger(__name__)
@ -90,11 +90,26 @@ class WorkerContext:
""" """
return self.device return self.device
def get_progress(self) -> int: def get_progress(self) -> Progress:
return self.get_last_steps()
def get_last_steps(self) -> Progress:
if self.last_progress is not None: if self.last_progress is not None:
return self.last_progress.steps return self.last_progress.steps
return 0 return Progress(0, 0)
def get_last_stages(self) -> Progress:
if self.last_progress is not None:
return self.last_progress.stages
return Progress(0, 0)
def get_last_tiles(self) -> Progress:
if self.last_progress is not None:
return self.last_progress.tiles
return Progress(0, 0)
def get_progress_callback(self, reset=False) -> ProgressCallback: def get_progress_callback(self, reset=False) -> ProgressCallback:
from ..chain.pipeline import ChainProgress from ..chain.pipeline import ChainProgress
@ -103,7 +118,7 @@ class WorkerContext:
return self.callback return self.callback
def on_progress(step: int, timestep: int, latents: Any): def on_progress(step: int, timestep: int, latents: Any):
self.callback.step = step # self.callback.step = step
self.set_progress( self.set_progress(
step, step,
stages=self.callback.stage, stages=self.callback.stage,
@ -138,9 +153,9 @@ class WorkerContext:
self.job_type, self.job_type,
self.device.device, self.device.device,
JobStatus.RUNNING, JobStatus.RUNNING,
steps=steps, steps=Progress(steps, self.callback.steps.total),
stages=stages, stages=Progress(stages, self.callback.stages.total),
tiles=tiles, tiles=Progress(tiles, self.callback.tiles.total),
result=result, result=result,
) )
self.progress.put( self.progress.put(
@ -158,9 +173,9 @@ class WorkerContext:
self.job_type, self.job_type,
self.device.device, self.device.device,
JobStatus.SUCCESS, JobStatus.SUCCESS,
steps=self.callback.step, steps=self.callback.steps,
stages=self.callback.stage, stages=self.callback.stages,
tiles=self.callback.tile, tiles=self.callback.tiles,
result=self.callback.result, result=self.callback.result,
) )
self.progress.put( self.progress.put(
@ -179,7 +194,10 @@ class WorkerContext:
self.job_type, self.job_type,
self.device.device, self.device.device,
JobStatus.FAILED, JobStatus.FAILED,
steps=self.get_progress(), steps=self.get_last_steps(),
stages=self.get_last_stages(),
tiles=self.get_last_tiles(),
# TODO: should this have results?
) )
self.progress.put( self.progress.put(
self.last_progress, self.last_progress,