use progress type in command
This commit is contained in:
parent
ce84dfa115
commit
b6da935be6
|
@ -5,6 +5,8 @@ from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..worker.command import Progress
|
||||||
|
|
||||||
from ..errors import CancelledException, RetryException
|
from ..errors import CancelledException, RetryException
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
|
@ -23,31 +25,43 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
||||||
|
|
||||||
class ChainProgress:
|
class ChainProgress:
|
||||||
parent: ProgressCallback
|
parent: ProgressCallback
|
||||||
step: int
|
step: int # same as steps.current, left for legacy purposes
|
||||||
total: int
|
prev: int # accumulator when step resets
|
||||||
stage: int
|
|
||||||
tile: int
|
# new progress trackers
|
||||||
|
steps: Progress
|
||||||
|
stages: Progress
|
||||||
|
tiles: Progress
|
||||||
result: Optional[StageResult]
|
result: Optional[StageResult]
|
||||||
# TODO: total stages and tiles
|
|
||||||
|
|
||||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.step = start
|
self.step = start
|
||||||
self.total = 0
|
self.prev = 0
|
||||||
self.stage = 0
|
self.steps = Progress(self.step, self.prev)
|
||||||
self.tile = 0
|
self.stages = Progress(0, 0)
|
||||||
|
self.tiles = Progress(0, 0)
|
||||||
self.result = None
|
self.result = None
|
||||||
|
|
||||||
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||||
if step < self.step:
|
if step < self.step:
|
||||||
# accumulate on resets
|
# accumulate on resets
|
||||||
self.total += self.step
|
self.prev += self.step
|
||||||
|
|
||||||
self.step = step
|
self.step = step
|
||||||
self.parent(self.get_total(), timestep, latents)
|
|
||||||
|
total = self.get_total()
|
||||||
|
self.steps.current = total
|
||||||
|
self.parent(total, timestep, latents)
|
||||||
|
|
||||||
def get_total(self) -> int:
|
def get_total(self) -> int:
|
||||||
return self.step + self.total
|
return self.step + self.prev
|
||||||
|
|
||||||
|
def set_total(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
|
||||||
|
self.prev = steps
|
||||||
|
self.steps.total = steps
|
||||||
|
self.stages.total = stages
|
||||||
|
self.tiles.total = tiles
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_progress(cls, parent: ProgressCallback):
|
def from_progress(cls, parent: ProgressCallback):
|
||||||
|
@ -61,6 +75,8 @@ class ChainPipeline:
|
||||||
tiles as needed.
|
tiles as needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
stages: List[PipelineStage]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stages: Optional[List[PipelineStage]] = None,
|
stages: Optional[List[PipelineStage]] = None,
|
||||||
|
@ -124,6 +140,11 @@ class ChainPipeline:
|
||||||
if not isinstance(callback, ChainProgress):
|
if not isinstance(callback, ChainProgress):
|
||||||
callback = ChainProgress.from_progress(callback)
|
callback = ChainProgress.from_progress(callback)
|
||||||
|
|
||||||
|
# set estimated totals
|
||||||
|
callback.set_total(
|
||||||
|
self.steps(params, sources.size), stages=len(self.stages), tiles=0
|
||||||
|
)
|
||||||
|
|
||||||
start = monotonic()
|
start = monotonic()
|
||||||
|
|
||||||
if len(sources) > 0:
|
if len(sources) > 0:
|
||||||
|
@ -145,7 +166,7 @@ class ChainPipeline:
|
||||||
len(stage_sources),
|
len(stage_sources),
|
||||||
kwargs.keys(),
|
kwargs.keys(),
|
||||||
)
|
)
|
||||||
callback.stage = stage_i
|
callback.stages.current = stage_i
|
||||||
|
|
||||||
per_stage_params = params
|
per_stage_params = params
|
||||||
if "params" in kwargs:
|
if "params" in kwargs:
|
||||||
|
@ -169,7 +190,7 @@ class ChainPipeline:
|
||||||
if stage_pipe.max_tile > 0:
|
if stage_pipe.max_tile > 0:
|
||||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||||
|
|
||||||
callback.tile = 0 # reset this either way
|
callback.tiles.current = 0 # reset this either way
|
||||||
if must_tile:
|
if must_tile:
|
||||||
logger.info(
|
logger.info(
|
||||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||||
|
@ -188,7 +209,9 @@ class ChainPipeline:
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
per_stage_params,
|
per_stage_params,
|
||||||
StageResult(images=source_tile, metadata=stage_sources.metadata),
|
StageResult(
|
||||||
|
images=source_tile, metadata=stage_sources.metadata
|
||||||
|
),
|
||||||
tile_mask=tile_mask,
|
tile_mask=tile_mask,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
dims=dims,
|
dims=dims,
|
||||||
|
@ -199,7 +222,7 @@ class ChainPipeline:
|
||||||
for j, image in enumerate(tile_result.as_image()):
|
for j, image in enumerate(tile_result.as_image()):
|
||||||
save_image(server, f"last-tile-{j}.png", image)
|
save_image(server, f"last-tile-{j}.png", image)
|
||||||
|
|
||||||
callback.tile = callback.tile + 1
|
callback.tiles.current = callback.tiles.current + 1
|
||||||
|
|
||||||
return tile_result
|
return tile_result
|
||||||
except CancelledException as err:
|
except CancelledException as err:
|
||||||
|
@ -226,7 +249,9 @@ class ChainPipeline:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
stage_sources = StageResult(images=stage_results)
|
stage_sources = StageResult(
|
||||||
|
images=stage_results, metadata=stage_sources.metadata
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"image does not contain sources and is within tile size of %s, running stage",
|
"image does not contain sources and is within tile size of %s, running stage",
|
||||||
|
|
|
@ -125,7 +125,7 @@ def run_txt2img_pipeline(
|
||||||
thumbnail = cover.copy()
|
thumbnail = cover.copy()
|
||||||
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
||||||
|
|
||||||
images.insert_image(0, thumbnail)
|
images.insert_image(0, thumbnail, images.metadata[0])
|
||||||
|
|
||||||
save_result(server, images, worker.job)
|
save_result(server, images, worker.job)
|
||||||
|
|
||||||
|
|
|
@ -13,24 +13,6 @@ Param = Union[str, int, float]
|
||||||
Point = Tuple[int, int]
|
Point = Tuple[int, int]
|
||||||
|
|
||||||
|
|
||||||
class Progress:
|
|
||||||
current: int
|
|
||||||
total: int
|
|
||||||
|
|
||||||
def __init__(self, current: int, total: int) -> None:
|
|
||||||
self.current = current
|
|
||||||
self.total = total
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return "%s/%s" % (self.current, self.total)
|
|
||||||
|
|
||||||
def tojson(self):
|
|
||||||
return {
|
|
||||||
"current": self.current,
|
|
||||||
"total": self.total,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SizeChart(IntEnum):
|
class SizeChart(IntEnum):
|
||||||
micro = 64
|
micro = 64
|
||||||
mini = 128 # small tile for very expensive models
|
mini = 128 # small tile for very expensive models
|
||||||
|
|
|
@ -21,15 +21,33 @@ class JobType(str, Enum):
|
||||||
CHAIN = "chain"
|
CHAIN = "chain"
|
||||||
|
|
||||||
|
|
||||||
|
class Progress:
|
||||||
|
current: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
def __init__(self, current: int, total: int) -> None:
|
||||||
|
self.current = current
|
||||||
|
self.total = total
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return "%s/%s" % (self.current, self.total)
|
||||||
|
|
||||||
|
def tojson(self):
|
||||||
|
return {
|
||||||
|
"current": self.current,
|
||||||
|
"total": self.total,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ProgressCommand:
|
class ProgressCommand:
|
||||||
device: str
|
device: str
|
||||||
job: str
|
job: str
|
||||||
job_type: str
|
job_type: str
|
||||||
status: JobStatus
|
status: JobStatus
|
||||||
result: Any
|
result: Any # really StageResult but that would be a very circular import
|
||||||
steps: int
|
steps: Progress
|
||||||
stages: int
|
stages: Progress
|
||||||
tiles: int
|
tiles: Progress
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -37,19 +55,21 @@ class ProgressCommand:
|
||||||
job_type: str,
|
job_type: str,
|
||||||
device: str,
|
device: str,
|
||||||
status: JobStatus,
|
status: JobStatus,
|
||||||
|
steps: Progress,
|
||||||
|
stages: Progress,
|
||||||
|
tiles: Progress,
|
||||||
result: Any = None,
|
result: Any = None,
|
||||||
steps: int = 0,
|
|
||||||
stages: int = 0,
|
|
||||||
tiles: int = 0,
|
|
||||||
):
|
):
|
||||||
self.job = job
|
self.job = job
|
||||||
self.job_type = job_type
|
self.job_type = job_type
|
||||||
self.device = device
|
self.device = device
|
||||||
self.status = status
|
self.status = status
|
||||||
self.result = result
|
|
||||||
|
# progress info
|
||||||
self.steps = steps
|
self.steps = steps
|
||||||
self.stages = stages
|
self.stages = stages
|
||||||
self.tiles = tiles
|
self.tiles = tiles
|
||||||
|
self.result = result
|
||||||
|
|
||||||
|
|
||||||
class JobCommand:
|
class JobCommand:
|
||||||
|
|
|
@ -7,7 +7,7 @@ from torch.multiprocessing import Queue, Value
|
||||||
|
|
||||||
from ..errors import CancelledException
|
from ..errors import CancelledException
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from .command import JobCommand, JobStatus, ProgressCommand
|
from .command import JobCommand, JobStatus, Progress, ProgressCommand
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -90,11 +90,26 @@ class WorkerContext:
|
||||||
"""
|
"""
|
||||||
return self.device
|
return self.device
|
||||||
|
|
||||||
def get_progress(self) -> int:
|
def get_progress(self) -> Progress:
|
||||||
|
return self.get_last_steps()
|
||||||
|
|
||||||
|
def get_last_steps(self) -> Progress:
|
||||||
if self.last_progress is not None:
|
if self.last_progress is not None:
|
||||||
return self.last_progress.steps
|
return self.last_progress.steps
|
||||||
|
|
||||||
return 0
|
return Progress(0, 0)
|
||||||
|
|
||||||
|
def get_last_stages(self) -> Progress:
|
||||||
|
if self.last_progress is not None:
|
||||||
|
return self.last_progress.stages
|
||||||
|
|
||||||
|
return Progress(0, 0)
|
||||||
|
|
||||||
|
def get_last_tiles(self) -> Progress:
|
||||||
|
if self.last_progress is not None:
|
||||||
|
return self.last_progress.tiles
|
||||||
|
|
||||||
|
return Progress(0, 0)
|
||||||
|
|
||||||
def get_progress_callback(self, reset=False) -> ProgressCallback:
|
def get_progress_callback(self, reset=False) -> ProgressCallback:
|
||||||
from ..chain.pipeline import ChainProgress
|
from ..chain.pipeline import ChainProgress
|
||||||
|
@ -103,7 +118,7 @@ class WorkerContext:
|
||||||
return self.callback
|
return self.callback
|
||||||
|
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
self.callback.step = step
|
# self.callback.step = step
|
||||||
self.set_progress(
|
self.set_progress(
|
||||||
step,
|
step,
|
||||||
stages=self.callback.stage,
|
stages=self.callback.stage,
|
||||||
|
@ -138,9 +153,9 @@ class WorkerContext:
|
||||||
self.job_type,
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
JobStatus.RUNNING,
|
JobStatus.RUNNING,
|
||||||
steps=steps,
|
steps=Progress(steps, self.callback.steps.total),
|
||||||
stages=stages,
|
stages=Progress(stages, self.callback.stages.total),
|
||||||
tiles=tiles,
|
tiles=Progress(tiles, self.callback.tiles.total),
|
||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
|
@ -158,9 +173,9 @@ class WorkerContext:
|
||||||
self.job_type,
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
JobStatus.SUCCESS,
|
JobStatus.SUCCESS,
|
||||||
steps=self.callback.step,
|
steps=self.callback.steps,
|
||||||
stages=self.callback.stage,
|
stages=self.callback.stages,
|
||||||
tiles=self.callback.tile,
|
tiles=self.callback.tiles,
|
||||||
result=self.callback.result,
|
result=self.callback.result,
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
|
@ -179,7 +194,10 @@ class WorkerContext:
|
||||||
self.job_type,
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
JobStatus.FAILED,
|
JobStatus.FAILED,
|
||||||
steps=self.get_progress(),
|
steps=self.get_last_steps(),
|
||||||
|
stages=self.get_last_stages(),
|
||||||
|
tiles=self.get_last_tiles(),
|
||||||
|
# TODO: should this have results?
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
self.last_progress,
|
self.last_progress,
|
||||||
|
|
Loading…
Reference in New Issue