feat(api): add chain pipeline stage result type
This commit is contained in:
parent
c8dd85e798
commit
d52c68d607
|
@ -1,49 +1,2 @@
|
||||||
from .base import ChainPipeline, PipelineStage, StageParams
|
from .pipeline import ChainPipeline, PipelineStage, StageParams
|
||||||
from .blend_denoise import BlendDenoiseStage
|
from .stages import *
|
||||||
from .blend_img2img import BlendImg2ImgStage
|
|
||||||
from .blend_grid import BlendGridStage
|
|
||||||
from .blend_linear import BlendLinearStage
|
|
||||||
from .blend_mask import BlendMaskStage
|
|
||||||
from .correct_codeformer import CorrectCodeformerStage
|
|
||||||
from .correct_gfpgan import CorrectGFPGANStage
|
|
||||||
from .persist_disk import PersistDiskStage
|
|
||||||
from .persist_s3 import PersistS3Stage
|
|
||||||
from .reduce_crop import ReduceCropStage
|
|
||||||
from .reduce_thumbnail import ReduceThumbnailStage
|
|
||||||
from .source_noise import SourceNoiseStage
|
|
||||||
from .source_s3 import SourceS3Stage
|
|
||||||
from .source_txt2img import SourceTxt2ImgStage
|
|
||||||
from .source_url import SourceURLStage
|
|
||||||
from .upscale_bsrgan import UpscaleBSRGANStage
|
|
||||||
from .upscale_highres import UpscaleHighresStage
|
|
||||||
from .upscale_outpaint import UpscaleOutpaintStage
|
|
||||||
from .upscale_resrgan import UpscaleRealESRGANStage
|
|
||||||
from .upscale_simple import UpscaleSimpleStage
|
|
||||||
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
|
||||||
from .upscale_swinir import UpscaleSwinIRStage
|
|
||||||
|
|
||||||
CHAIN_STAGES = {
|
|
||||||
"blend-denoise": BlendDenoiseStage,
|
|
||||||
"blend-img2img": BlendImg2ImgStage,
|
|
||||||
"blend-inpaint": UpscaleOutpaintStage,
|
|
||||||
"blend-grid": BlendGridStage,
|
|
||||||
"blend-linear": BlendLinearStage,
|
|
||||||
"blend-mask": BlendMaskStage,
|
|
||||||
"correct-codeformer": CorrectCodeformerStage,
|
|
||||||
"correct-gfpgan": CorrectGFPGANStage,
|
|
||||||
"persist-disk": PersistDiskStage,
|
|
||||||
"persist-s3": PersistS3Stage,
|
|
||||||
"reduce-crop": ReduceCropStage,
|
|
||||||
"reduce-thumbnail": ReduceThumbnailStage,
|
|
||||||
"source-noise": SourceNoiseStage,
|
|
||||||
"source-s3": SourceS3Stage,
|
|
||||||
"source-txt2img": SourceTxt2ImgStage,
|
|
||||||
"source-url": SourceURLStage,
|
|
||||||
"upscale-bsrgan": UpscaleBSRGANStage,
|
|
||||||
"upscale-highres": UpscaleHighresStage,
|
|
||||||
"upscale-outpaint": UpscaleOutpaintStage,
|
|
||||||
"upscale-resrgan": UpscaleRealESRGANStage,
|
|
||||||
"upscale-simple": UpscaleSimpleStage,
|
|
||||||
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
|
|
||||||
"upscale-swinir": UpscaleSwinIRStage,
|
|
||||||
}
|
|
|
@ -1,283 +1,39 @@
|
||||||
from datetime import timedelta
|
from typing import List, Optional
|
||||||
from logging import getLogger
|
|
||||||
from time import monotonic
|
|
||||||
from typing import Any, List, Optional, Tuple
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..errors import RetryException
|
from .result import StageResult
|
||||||
from ..output import save_image
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..server.context import ServerContext
|
||||||
from ..server import ServerContext
|
from ..worker.context import WorkerContext
|
||||||
from ..utils import is_debug, run_gc
|
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
|
||||||
from .stage import BaseStage
|
|
||||||
from .tile import needs_tile, process_tile_order
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
class BaseStage:
|
||||||
|
max_tile = SizeChart.auto
|
||||||
|
|
||||||
class ChainProgress:
|
|
||||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
|
||||||
self.parent = parent
|
|
||||||
self.step = start
|
|
||||||
self.total = 0
|
|
||||||
|
|
||||||
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
|
||||||
if step < self.step:
|
|
||||||
# accumulate on resets
|
|
||||||
self.total += self.step
|
|
||||||
|
|
||||||
self.step = step
|
|
||||||
self.parent(self.get_total(), timestep, latents)
|
|
||||||
|
|
||||||
def get_total(self) -> int:
|
|
||||||
return self.step + self.total
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_progress(cls, parent: ProgressCallback):
|
|
||||||
start = parent.step if hasattr(parent, "step") else 0
|
|
||||||
return ChainProgress(parent, start=start)
|
|
||||||
|
|
||||||
|
|
||||||
class ChainPipeline:
|
|
||||||
"""
|
|
||||||
Run many stages in series, passing the image results from each to the next, and processing
|
|
||||||
tiles as needed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
stages: Optional[List[PipelineStage]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new pipeline that will run the given stages.
|
|
||||||
"""
|
|
||||||
self.stages = list(stages or [])
|
|
||||||
|
|
||||||
def append(self, stage: Optional[PipelineStage]):
|
|
||||||
"""
|
|
||||||
Append an additional stage to this pipeline.
|
|
||||||
|
|
||||||
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
|
|
||||||
assemble the stage from loose arguments.
|
|
||||||
"""
|
|
||||||
if stage is not None:
|
|
||||||
self.stages.append(stage)
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
worker: WorkerContext,
|
_worker: WorkerContext,
|
||||||
server: ServerContext,
|
_server: ServerContext,
|
||||||
params: ImageParams,
|
_stage: StageParams,
|
||||||
sources: List[Image.Image],
|
_params: ImageParams,
|
||||||
callback: Optional[ProgressCallback],
|
_sources: List[Image.Image],
|
||||||
**kwargs
|
*args,
|
||||||
) -> List[Image.Image]:
|
stage_source: Optional[Image.Image] = None,
|
||||||
return self(
|
**kwargs,
|
||||||
worker, server, params, sources=sources, callback=callback, **kwargs
|
) -> StageResult:
|
||||||
)
|
raise NotImplementedError() # noqa
|
||||||
|
|
||||||
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
def steps(
|
||||||
self.stages.append((callback, params, kwargs))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def steps(self, params: ImageParams, size: Size):
|
|
||||||
steps = 0
|
|
||||||
for callback, _params, kwargs in self.stages:
|
|
||||||
steps += callback.steps(kwargs.get("params", params), size)
|
|
||||||
|
|
||||||
return steps
|
|
||||||
|
|
||||||
def outputs(self, params: ImageParams, sources: int):
|
|
||||||
outputs = sources
|
|
||||||
for callback, _params, kwargs in self.stages:
|
|
||||||
outputs = callback.outputs(kwargs.get("params", params), outputs)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
self,
|
||||||
worker: WorkerContext,
|
_params: ImageParams,
|
||||||
server: ServerContext,
|
_size: Size,
|
||||||
params: ImageParams,
|
) -> int:
|
||||||
sources: List[Image.Image],
|
return 1 # noqa
|
||||||
callback: Optional[ProgressCallback] = None,
|
|
||||||
**pipeline_kwargs
|
|
||||||
) -> List[Image.Image]:
|
|
||||||
"""
|
|
||||||
DEPRECATED: use `run` instead
|
|
||||||
"""
|
|
||||||
if callback is None:
|
|
||||||
callback = worker.get_progress_callback()
|
|
||||||
else:
|
|
||||||
callback = ChainProgress.from_progress(callback)
|
|
||||||
|
|
||||||
start = monotonic()
|
def outputs(
|
||||||
|
self,
|
||||||
if len(sources) > 0:
|
_params: ImageParams,
|
||||||
logger.info(
|
sources: int,
|
||||||
"running pipeline on %s source images",
|
) -> int:
|
||||||
len(sources),
|
return sources
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("running pipeline without source images")
|
|
||||||
|
|
||||||
stage_sources = sources
|
|
||||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
|
||||||
name = stage_params.name or stage_pipe.__class__.__name__
|
|
||||||
kwargs = stage_kwargs or {}
|
|
||||||
kwargs = {**pipeline_kwargs, **kwargs}
|
|
||||||
logger.debug(
|
|
||||||
"running stage %s with %s source images, parameters: %s",
|
|
||||||
name,
|
|
||||||
len(stage_sources) - stage_sources.count(None),
|
|
||||||
kwargs.keys(),
|
|
||||||
)
|
|
||||||
|
|
||||||
per_stage_params = params
|
|
||||||
if "params" in kwargs:
|
|
||||||
per_stage_params = kwargs["params"]
|
|
||||||
kwargs.pop("params")
|
|
||||||
|
|
||||||
# the stage must be split and tiled if any image is larger than the selected/max tile size
|
|
||||||
must_tile = any(
|
|
||||||
[
|
|
||||||
needs_tile(
|
|
||||||
stage_pipe.max_tile,
|
|
||||||
stage_params.tile_size,
|
|
||||||
size=kwargs.get("size", None),
|
|
||||||
source=source,
|
|
||||||
)
|
|
||||||
for source in stage_sources
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
tile = stage_params.tile_size
|
|
||||||
if stage_pipe.max_tile > 0:
|
|
||||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
|
||||||
|
|
||||||
if stage_sources or must_tile:
|
|
||||||
stage_outputs = []
|
|
||||||
for source in stage_sources:
|
|
||||||
logger.info(
|
|
||||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
|
||||||
tile,
|
|
||||||
)
|
|
||||||
|
|
||||||
extra_tiles = []
|
|
||||||
|
|
||||||
def stage_tile(
|
|
||||||
source_tile: Image.Image,
|
|
||||||
tile_mask: Image.Image,
|
|
||||||
dims: Tuple[int, int, int],
|
|
||||||
) -> Image.Image:
|
|
||||||
for _i in range(worker.retries):
|
|
||||||
try:
|
|
||||||
output_tile = stage_pipe.run(
|
|
||||||
worker,
|
|
||||||
server,
|
|
||||||
stage_params,
|
|
||||||
per_stage_params,
|
|
||||||
[source_tile],
|
|
||||||
tile_mask=tile_mask,
|
|
||||||
callback=callback,
|
|
||||||
dims=dims,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(output_tile) > 1:
|
|
||||||
while len(extra_tiles) < len(output_tile):
|
|
||||||
extra_tiles.append([])
|
|
||||||
|
|
||||||
for tile, layer in zip(output_tile, extra_tiles):
|
|
||||||
layer.append((tile, dims))
|
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "last-tile.png", output_tile[0])
|
|
||||||
|
|
||||||
return output_tile[0]
|
|
||||||
except Exception:
|
|
||||||
worker.retries = worker.retries - 1
|
|
||||||
logger.exception(
|
|
||||||
"error while running stage pipeline for tile, %s retries left",
|
|
||||||
worker.retries,
|
|
||||||
)
|
|
||||||
server.cache.clear()
|
|
||||||
run_gc([worker.get_device()])
|
|
||||||
|
|
||||||
raise RetryException("exhausted retries on tile")
|
|
||||||
|
|
||||||
output = process_tile_order(
|
|
||||||
stage_params.tile_order,
|
|
||||||
source,
|
|
||||||
tile,
|
|
||||||
stage_params.outscale,
|
|
||||||
[stage_tile],
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
stage_outputs.append(output)
|
|
||||||
|
|
||||||
if len(extra_tiles) > 1:
|
|
||||||
for layer in extra_tiles:
|
|
||||||
layer_output = Image.new("RGB", output.size)
|
|
||||||
for layer_tile, dims in layer:
|
|
||||||
layer_output.paste(layer_tile, (dims[0], dims[1]))
|
|
||||||
|
|
||||||
stage_outputs.append(layer_output)
|
|
||||||
|
|
||||||
stage_sources = stage_outputs
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"image does not contain sources and is within tile size of %s, running stage",
|
|
||||||
tile,
|
|
||||||
)
|
|
||||||
for i in range(worker.retries):
|
|
||||||
try:
|
|
||||||
stage_outputs = stage_pipe.run(
|
|
||||||
worker,
|
|
||||||
server,
|
|
||||||
stage_params,
|
|
||||||
per_stage_params,
|
|
||||||
stage_sources,
|
|
||||||
callback=callback,
|
|
||||||
dims=(0, 0, tile),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
|
|
||||||
# does not like, so it throws
|
|
||||||
stage_sources = stage_outputs
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
worker.retries = worker.retries - 1
|
|
||||||
logger.exception(
|
|
||||||
"error while running stage pipeline, %s retries left",
|
|
||||||
worker.retries,
|
|
||||||
)
|
|
||||||
server.cache.clear()
|
|
||||||
run_gc([worker.get_device()])
|
|
||||||
|
|
||||||
if worker.retries <= 0:
|
|
||||||
raise RetryException("exhausted retries on stage")
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"finished stage %s with %s results",
|
|
||||||
name,
|
|
||||||
len(stage_sources),
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "last-stage.png", stage_sources[0])
|
|
||||||
|
|
||||||
end = monotonic()
|
|
||||||
duration = timedelta(seconds=(end - start))
|
|
||||||
logger.info(
|
|
||||||
"finished pipeline in %s with %s results",
|
|
||||||
duration,
|
|
||||||
len(stage_sources),
|
|
||||||
)
|
|
||||||
return stage_sources
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ..diffusers.utils import encode_prompt, parse_prompt, slice_prompt
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..chain.base import ChainPipeline
|
from .pipeline import ChainPipeline
|
||||||
from ..chain.blend_img2img import BlendImg2ImgStage
|
from ..chain.blend_img2img import BlendImg2ImgStage
|
||||||
from ..chain.upscale import stage_upscale_correction
|
from ..chain.upscale import stage_upscale_correction
|
||||||
from ..chain.upscale_simple import UpscaleSimpleStage
|
from ..chain.upscale_simple import UpscaleSimpleStage
|
||||||
|
|
|
@ -7,7 +7,7 @@ from ..output import save_image
|
||||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,283 @@
|
||||||
|
from datetime import timedelta
|
||||||
|
from logging import getLogger
|
||||||
|
from time import monotonic
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..errors import RetryException
|
||||||
|
from ..output import save_image
|
||||||
|
from ..params import ImageParams, Size, StageParams
|
||||||
|
from ..server import ServerContext
|
||||||
|
from ..utils import is_debug, run_gc
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
from .base import BaseStage
|
||||||
|
from .tile import needs_tile, process_tile_order
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
||||||
|
|
||||||
|
|
||||||
|
class ChainProgress:
|
||||||
|
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||||
|
self.parent = parent
|
||||||
|
self.step = start
|
||||||
|
self.total = 0
|
||||||
|
|
||||||
|
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||||
|
if step < self.step:
|
||||||
|
# accumulate on resets
|
||||||
|
self.total += self.step
|
||||||
|
|
||||||
|
self.step = step
|
||||||
|
self.parent(self.get_total(), timestep, latents)
|
||||||
|
|
||||||
|
def get_total(self) -> int:
|
||||||
|
return self.step + self.total
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_progress(cls, parent: ProgressCallback):
|
||||||
|
start = parent.step if hasattr(parent, "step") else 0
|
||||||
|
return ChainProgress(parent, start=start)
|
||||||
|
|
||||||
|
|
||||||
|
class ChainPipeline:
|
||||||
|
"""
|
||||||
|
Run many stages in series, passing the image results from each to the next, and processing
|
||||||
|
tiles as needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stages: Optional[List[PipelineStage]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new pipeline that will run the given stages.
|
||||||
|
"""
|
||||||
|
self.stages = list(stages or [])
|
||||||
|
|
||||||
|
def append(self, stage: Optional[PipelineStage]):
|
||||||
|
"""
|
||||||
|
Append an additional stage to this pipeline.
|
||||||
|
|
||||||
|
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
|
||||||
|
assemble the stage from loose arguments.
|
||||||
|
"""
|
||||||
|
if stage is not None:
|
||||||
|
self.stages.append(stage)
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
worker: WorkerContext,
|
||||||
|
server: ServerContext,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: List[Image.Image],
|
||||||
|
callback: Optional[ProgressCallback],
|
||||||
|
**kwargs
|
||||||
|
) -> List[Image.Image]:
|
||||||
|
return self(
|
||||||
|
worker, server, params, sources=sources, callback=callback, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
||||||
|
self.stages.append((callback, params, kwargs))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def steps(self, params: ImageParams, size: Size):
|
||||||
|
steps = 0
|
||||||
|
for callback, _params, kwargs in self.stages:
|
||||||
|
steps += callback.steps(kwargs.get("params", params), size)
|
||||||
|
|
||||||
|
return steps
|
||||||
|
|
||||||
|
def outputs(self, params: ImageParams, sources: int):
|
||||||
|
outputs = sources
|
||||||
|
for callback, _params, kwargs in self.stages:
|
||||||
|
outputs = callback.outputs(kwargs.get("params", params), outputs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
worker: WorkerContext,
|
||||||
|
server: ServerContext,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: List[Image.Image],
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
|
**pipeline_kwargs
|
||||||
|
) -> List[Image.Image]:
|
||||||
|
"""
|
||||||
|
DEPRECATED: use `run` instead
|
||||||
|
"""
|
||||||
|
if callback is None:
|
||||||
|
callback = worker.get_progress_callback()
|
||||||
|
else:
|
||||||
|
callback = ChainProgress.from_progress(callback)
|
||||||
|
|
||||||
|
start = monotonic()
|
||||||
|
|
||||||
|
if len(sources) > 0:
|
||||||
|
logger.info(
|
||||||
|
"running pipeline on %s source images",
|
||||||
|
len(sources),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("running pipeline without source images")
|
||||||
|
|
||||||
|
stage_sources = sources
|
||||||
|
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||||
|
name = stage_params.name or stage_pipe.__class__.__name__
|
||||||
|
kwargs = stage_kwargs or {}
|
||||||
|
kwargs = {**pipeline_kwargs, **kwargs}
|
||||||
|
logger.debug(
|
||||||
|
"running stage %s with %s source images, parameters: %s",
|
||||||
|
name,
|
||||||
|
len(stage_sources) - stage_sources.count(None),
|
||||||
|
kwargs.keys(),
|
||||||
|
)
|
||||||
|
|
||||||
|
per_stage_params = params
|
||||||
|
if "params" in kwargs:
|
||||||
|
per_stage_params = kwargs["params"]
|
||||||
|
kwargs.pop("params")
|
||||||
|
|
||||||
|
# the stage must be split and tiled if any image is larger than the selected/max tile size
|
||||||
|
must_tile = any(
|
||||||
|
[
|
||||||
|
needs_tile(
|
||||||
|
stage_pipe.max_tile,
|
||||||
|
stage_params.tile_size,
|
||||||
|
size=kwargs.get("size", None),
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
for source in stage_sources
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
tile = stage_params.tile_size
|
||||||
|
if stage_pipe.max_tile > 0:
|
||||||
|
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||||
|
|
||||||
|
if stage_sources or must_tile:
|
||||||
|
stage_outputs = []
|
||||||
|
for source in stage_sources:
|
||||||
|
logger.info(
|
||||||
|
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||||
|
tile,
|
||||||
|
)
|
||||||
|
|
||||||
|
extra_tiles = []
|
||||||
|
|
||||||
|
def stage_tile(
|
||||||
|
source_tile: Image.Image,
|
||||||
|
tile_mask: Image.Image,
|
||||||
|
dims: Tuple[int, int, int],
|
||||||
|
) -> Image.Image:
|
||||||
|
for _i in range(worker.retries):
|
||||||
|
try:
|
||||||
|
output_tile = stage_pipe.run(
|
||||||
|
worker,
|
||||||
|
server,
|
||||||
|
stage_params,
|
||||||
|
per_stage_params,
|
||||||
|
[source_tile],
|
||||||
|
tile_mask=tile_mask,
|
||||||
|
callback=callback,
|
||||||
|
dims=dims,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(output_tile) > 1:
|
||||||
|
while len(extra_tiles) < len(output_tile):
|
||||||
|
extra_tiles.append([])
|
||||||
|
|
||||||
|
for tile, layer in zip(output_tile, extra_tiles):
|
||||||
|
layer.append((tile, dims))
|
||||||
|
|
||||||
|
if is_debug():
|
||||||
|
save_image(server, "last-tile.png", output_tile[0])
|
||||||
|
|
||||||
|
return output_tile[0]
|
||||||
|
except Exception:
|
||||||
|
worker.retries = worker.retries - 1
|
||||||
|
logger.exception(
|
||||||
|
"error while running stage pipeline for tile, %s retries left",
|
||||||
|
worker.retries,
|
||||||
|
)
|
||||||
|
server.cache.clear()
|
||||||
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
|
raise RetryException("exhausted retries on tile")
|
||||||
|
|
||||||
|
output = process_tile_order(
|
||||||
|
stage_params.tile_order,
|
||||||
|
source,
|
||||||
|
tile,
|
||||||
|
stage_params.outscale,
|
||||||
|
[stage_tile],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
stage_outputs.append(output)
|
||||||
|
|
||||||
|
if len(extra_tiles) > 1:
|
||||||
|
for layer in extra_tiles:
|
||||||
|
layer_output = Image.new("RGB", output.size)
|
||||||
|
for layer_tile, dims in layer:
|
||||||
|
layer_output.paste(layer_tile, (dims[0], dims[1]))
|
||||||
|
|
||||||
|
stage_outputs.append(layer_output)
|
||||||
|
|
||||||
|
stage_sources = stage_outputs
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"image does not contain sources and is within tile size of %s, running stage",
|
||||||
|
tile,
|
||||||
|
)
|
||||||
|
for i in range(worker.retries):
|
||||||
|
try:
|
||||||
|
stage_outputs = stage_pipe.run(
|
||||||
|
worker,
|
||||||
|
server,
|
||||||
|
stage_params,
|
||||||
|
per_stage_params,
|
||||||
|
stage_sources,
|
||||||
|
callback=callback,
|
||||||
|
dims=(0, 0, tile),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
|
||||||
|
# does not like, so it throws
|
||||||
|
stage_sources = stage_outputs
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
worker.retries = worker.retries - 1
|
||||||
|
logger.exception(
|
||||||
|
"error while running stage pipeline, %s retries left",
|
||||||
|
worker.retries,
|
||||||
|
)
|
||||||
|
server.cache.clear()
|
||||||
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
|
if worker.retries <= 0:
|
||||||
|
raise RetryException("exhausted retries on stage")
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"finished stage %s with %s results",
|
||||||
|
name,
|
||||||
|
len(stage_sources),
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_debug():
|
||||||
|
save_image(server, "last-stage.png", stage_sources[0])
|
||||||
|
|
||||||
|
end = monotonic()
|
||||||
|
duration = timedelta(seconds=(end - start))
|
||||||
|
logger.info(
|
||||||
|
"finished pipeline in %s with %s results",
|
||||||
|
duration,
|
||||||
|
len(stage_sources),
|
||||||
|
)
|
||||||
|
return stage_sources
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
from PIL.Image import Image, fromarray
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class StageResult:
|
||||||
|
"""
|
||||||
|
Chain pipeline stage result.
|
||||||
|
Can contain PIL images or numpy arrays, with helpers to convert between them.
|
||||||
|
"""
|
||||||
|
arrays: Optional[List[np.ndarray]]
|
||||||
|
images: Optional[List[Image]]
|
||||||
|
|
||||||
|
def __init__(self, arrays = None, images = None) -> None:
|
||||||
|
if arrays is not None and images is not None:
|
||||||
|
raise ValueError("stages must only return one type of result")
|
||||||
|
|
||||||
|
self.arrays = arrays
|
||||||
|
self.images = images
|
||||||
|
|
||||||
|
def as_numpy(self) -> List[np.ndarray]:
|
||||||
|
if self.arrays is not None:
|
||||||
|
return self.arrays
|
||||||
|
|
||||||
|
return [np.array(i) for i in self.images]
|
||||||
|
|
||||||
|
def as_image(self) -> List[Image]:
|
||||||
|
if self.images is not None:
|
||||||
|
return self.images
|
||||||
|
|
||||||
|
return [fromarray(i) for i in self.arrays]
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ from ..diffusers.utils import (
|
||||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -1,38 +0,0 @@
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
|
||||||
from ..server.context import ServerContext
|
|
||||||
from ..worker.context import WorkerContext
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStage:
|
|
||||||
max_tile = SizeChart.auto
|
|
||||||
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
_worker: WorkerContext,
|
|
||||||
_server: ServerContext,
|
|
||||||
_stage: StageParams,
|
|
||||||
_params: ImageParams,
|
|
||||||
_sources: List[Image.Image],
|
|
||||||
*args,
|
|
||||||
stage_source: Optional[Image.Image] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> List[Image.Image]:
|
|
||||||
raise NotImplementedError() # noqa
|
|
||||||
|
|
||||||
def steps(
|
|
||||||
self,
|
|
||||||
_params: ImageParams,
|
|
||||||
_size: Size,
|
|
||||||
) -> int:
|
|
||||||
return 1 # noqa
|
|
||||||
|
|
||||||
def outputs(
|
|
||||||
self,
|
|
||||||
_params: ImageParams,
|
|
||||||
sources: int,
|
|
||||||
) -> int:
|
|
||||||
return sources
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from .base import BaseStage
|
||||||
|
from .blend_denoise import BlendDenoiseStage
|
||||||
|
from .blend_img2img import BlendImg2ImgStage
|
||||||
|
from .blend_grid import BlendGridStage
|
||||||
|
from .blend_linear import BlendLinearStage
|
||||||
|
from .blend_mask import BlendMaskStage
|
||||||
|
from .correct_codeformer import CorrectCodeformerStage
|
||||||
|
from .correct_gfpgan import CorrectGFPGANStage
|
||||||
|
from .persist_disk import PersistDiskStage
|
||||||
|
from .persist_s3 import PersistS3Stage
|
||||||
|
from .reduce_crop import ReduceCropStage
|
||||||
|
from .reduce_thumbnail import ReduceThumbnailStage
|
||||||
|
from .source_noise import SourceNoiseStage
|
||||||
|
from .source_s3 import SourceS3Stage
|
||||||
|
from .source_txt2img import SourceTxt2ImgStage
|
||||||
|
from .source_url import SourceURLStage
|
||||||
|
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||||
|
from .upscale_highres import UpscaleHighresStage
|
||||||
|
from .upscale_outpaint import UpscaleOutpaintStage
|
||||||
|
from .upscale_resrgan import UpscaleRealESRGANStage
|
||||||
|
from .upscale_simple import UpscaleSimpleStage
|
||||||
|
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
||||||
|
from .upscale_swinir import UpscaleSwinIRStage
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
CHAIN_STAGES = {
|
||||||
|
"blend-denoise": BlendDenoiseStage,
|
||||||
|
"blend-img2img": BlendImg2ImgStage,
|
||||||
|
"blend-inpaint": UpscaleOutpaintStage,
|
||||||
|
"blend-grid": BlendGridStage,
|
||||||
|
"blend-linear": BlendLinearStage,
|
||||||
|
"blend-mask": BlendMaskStage,
|
||||||
|
"correct-codeformer": CorrectCodeformerStage,
|
||||||
|
"correct-gfpgan": CorrectGFPGANStage,
|
||||||
|
"persist-disk": PersistDiskStage,
|
||||||
|
"persist-s3": PersistS3Stage,
|
||||||
|
"reduce-crop": ReduceCropStage,
|
||||||
|
"reduce-thumbnail": ReduceThumbnailStage,
|
||||||
|
"source-noise": SourceNoiseStage,
|
||||||
|
"source-s3": SourceS3Stage,
|
||||||
|
"source-txt2img": SourceTxt2ImgStage,
|
||||||
|
"source-url": SourceURLStage,
|
||||||
|
"upscale-bsrgan": UpscaleBSRGANStage,
|
||||||
|
"upscale-highres": UpscaleHighresStage,
|
||||||
|
"upscale-outpaint": UpscaleOutpaintStage,
|
||||||
|
"upscale-resrgan": UpscaleRealESRGANStage,
|
||||||
|
"upscale-simple": UpscaleSimpleStage,
|
||||||
|
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
|
||||||
|
"upscale-swinir": UpscaleSwinIRStage,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def add_stage(name: str, stage: BaseStage) -> bool:
|
||||||
|
global CHAIN_STAGES
|
||||||
|
|
||||||
|
if name in CHAIN_STAGES:
|
||||||
|
logger.warning("cannot replace stage: %s", name)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
CHAIN_STAGES[name] = stage
|
||||||
|
return True
|
|
@ -10,7 +10,7 @@ from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from ..worker.context import ProgressCallback
|
from ..worker.context import ProgressCallback
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from PIL import Image
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ..diffusers.utils import encode_prompt, parse_prompt
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,7 @@ class WorkerContext:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def get_progress_callback(self) -> ProgressCallback:
|
def get_progress_callback(self) -> ProgressCallback:
|
||||||
from ..chain.base import ChainProgress
|
from ..chain.pipeline import ChainProgress
|
||||||
|
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
on_progress.step = step
|
on_progress.step = step
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from onnx_web.chain.base import ChainProgress
|
from onnx_web.chain.pipeline import ChainProgress
|
||||||
|
|
||||||
|
|
||||||
class ChainProgressTests(unittest.TestCase):
|
class ChainProgressTests(unittest.TestCase):
|
||||||
|
|
Loading…
Reference in New Issue