1
0
Fork 0

feat(api): add chain pipeline stage result type

This commit is contained in:
Sean Sube 2023-11-18 17:18:23 -06:00
parent c8dd85e798
commit d52c68d607
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
31 changed files with 433 additions and 384 deletions

View File

@ -1,49 +1,2 @@
from .base import ChainPipeline, PipelineStage, StageParams
from .blend_denoise import BlendDenoiseStage
from .blend_img2img import BlendImg2ImgStage
from .blend_grid import BlendGridStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_simple import UpscaleSimpleStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
CHAIN_STAGES = {
"blend-denoise": BlendDenoiseStage,
"blend-img2img": BlendImg2ImgStage,
"blend-inpaint": UpscaleOutpaintStage,
"blend-grid": BlendGridStage,
"blend-linear": BlendLinearStage,
"blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,
"reduce-thumbnail": ReduceThumbnailStage,
"source-noise": SourceNoiseStage,
"source-s3": SourceS3Stage,
"source-txt2img": SourceTxt2ImgStage,
"source-url": SourceURLStage,
"upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-simple": UpscaleSimpleStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}
from .pipeline import ChainPipeline, PipelineStage, StageParams
from .stages import *

View File

@ -1,283 +1,39 @@
from datetime import timedelta
from logging import getLogger
from time import monotonic
from typing import Any, List, Optional, Tuple
from typing import List, Optional
from PIL import Image
from ..errors import RetryException
from ..output import save_image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .tile import needs_tile, process_tile_order
logger = getLogger(__name__)
from .result import StageResult
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server.context import ServerContext
from ..worker.context import WorkerContext
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0
def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step
self.step = step
self.parent(self.get_total(), timestep, latents)
def get_total(self) -> int:
return self.step + self.total
@classmethod
def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0
return ChainProgress(parent, start=start)
class ChainPipeline:
"""
Run many stages in series, passing the image results from each to the next, and processing
tiles as needed.
"""
def __init__(
self,
stages: Optional[List[PipelineStage]] = None,
):
"""
Create a new pipeline that will run the given stages.
"""
self.stages = list(stages or [])
def append(self, stage: Optional[PipelineStage]):
"""
Append an additional stage to this pipeline.
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
assemble the stage from loose arguments.
"""
if stage is not None:
self.stages.append(stage)
class BaseStage:
max_tile = SizeChart.auto
def run(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: List[Image.Image],
callback: Optional[ProgressCallback],
**kwargs
) -> List[Image.Image]:
return self(
worker, server, params, sources=sources, callback=callback, **kwargs
)
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
_sources: List[Image.Image],
*args,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> StageResult:
raise NotImplementedError() # noqa
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self
def steps(self, params: ImageParams, size: Size):
steps = 0
for callback, _params, kwargs in self.stages:
steps += callback.steps(kwargs.get("params", params), size)
return steps
def outputs(self, params: ImageParams, sources: int):
outputs = sources
for callback, _params, kwargs in self.stages:
outputs = callback.outputs(kwargs.get("params", params), outputs)
return outputs
def __call__(
def steps(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: List[Image.Image],
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs
) -> List[Image.Image]:
"""
DEPRECATED: use `run` instead
"""
if callback is None:
callback = worker.get_progress_callback()
else:
callback = ChainProgress.from_progress(callback)
_params: ImageParams,
_size: Size,
) -> int:
return 1 # noqa
start = monotonic()
if len(sources) > 0:
logger.info(
"running pipeline on %s source images",
len(sources),
)
else:
logger.info("running pipeline without source images")
stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
logger.debug(
"running stage %s with %s source images, parameters: %s",
name,
len(stage_sources) - stage_sources.count(None),
kwargs.keys(),
)
per_stage_params = params
if "params" in kwargs:
per_stage_params = kwargs["params"]
kwargs.pop("params")
# the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = any(
[
needs_tile(
stage_pipe.max_tile,
stage_params.tile_size,
size=kwargs.get("size", None),
source=source,
)
for source in stage_sources
]
)
tile = stage_params.tile_size
if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
if stage_sources or must_tile:
stage_outputs = []
for source in stage_sources:
logger.info(
"image contains sources or is larger than tile size of %s, tiling stage",
tile,
)
extra_tiles = []
def stage_tile(
source_tile: Image.Image,
tile_mask: Image.Image,
dims: Tuple[int, int, int],
) -> Image.Image:
for _i in range(worker.retries):
try:
output_tile = stage_pipe.run(
worker,
server,
stage_params,
per_stage_params,
[source_tile],
tile_mask=tile_mask,
callback=callback,
dims=dims,
**kwargs,
)
if len(output_tile) > 1:
while len(extra_tiles) < len(output_tile):
extra_tiles.append([])
for tile, layer in zip(output_tile, extra_tiles):
layer.append((tile, dims))
if is_debug():
save_image(server, "last-tile.png", output_tile[0])
return output_tile[0]
except Exception:
worker.retries = worker.retries - 1
logger.exception(
"error while running stage pipeline for tile, %s retries left",
worker.retries,
)
server.cache.clear()
run_gc([worker.get_device()])
raise RetryException("exhausted retries on tile")
output = process_tile_order(
stage_params.tile_order,
source,
tile,
stage_params.outscale,
[stage_tile],
**kwargs,
)
stage_outputs.append(output)
if len(extra_tiles) > 1:
for layer in extra_tiles:
layer_output = Image.new("RGB", output.size)
for layer_tile, dims in layer:
layer_output.paste(layer_tile, (dims[0], dims[1]))
stage_outputs.append(layer_output)
stage_sources = stage_outputs
else:
logger.debug(
"image does not contain sources and is within tile size of %s, running stage",
tile,
)
for i in range(worker.retries):
try:
stage_outputs = stage_pipe.run(
worker,
server,
stage_params,
per_stage_params,
stage_sources,
callback=callback,
dims=(0, 0, tile),
**kwargs,
)
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
# does not like, so it throws
stage_sources = stage_outputs
break
except Exception:
worker.retries = worker.retries - 1
logger.exception(
"error while running stage pipeline, %s retries left",
worker.retries,
)
server.cache.clear()
run_gc([worker.get_device()])
if worker.retries <= 0:
raise RetryException("exhausted retries on stage")
logger.debug(
"finished stage %s with %s results",
name,
len(stage_sources),
)
if is_debug():
save_image(server, "last-stage.png", stage_sources[0])
end = monotonic()
duration = timedelta(seconds=(end - start))
logger.info(
"finished pipeline in %s with %s results",
duration,
len(stage_sources),
)
return stage_sources
def outputs(
self,
_params: ImageParams,
sources: int,
) -> int:
return sources

View File

@ -8,7 +8,7 @@ from PIL import Image
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -6,7 +6,7 @@ from PIL import Image
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -10,7 +10,7 @@ from ..diffusers.utils import encode_prompt, parse_prompt, slice_prompt
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -6,7 +6,7 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -8,7 +8,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -6,7 +6,7 @@ from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -9,7 +9,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -1,7 +1,7 @@
from logging import getLogger
from typing import Optional
from ..chain.base import ChainPipeline
from .pipeline import ChainPipeline
from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage

View File

@ -7,7 +7,7 @@ from ..output import save_image
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -8,7 +8,7 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -0,0 +1,283 @@
from datetime import timedelta
from logging import getLogger
from time import monotonic
from typing import Any, List, Optional, Tuple
from PIL import Image
from ..errors import RetryException
from ..output import save_image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .tile import needs_tile, process_tile_order
logger = getLogger(__name__)
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0
def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step
self.step = step
self.parent(self.get_total(), timestep, latents)
def get_total(self) -> int:
return self.step + self.total
@classmethod
def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0
return ChainProgress(parent, start=start)
class ChainPipeline:
"""
Run many stages in series, passing the image results from each to the next, and processing
tiles as needed.
"""
def __init__(
self,
stages: Optional[List[PipelineStage]] = None,
):
"""
Create a new pipeline that will run the given stages.
"""
self.stages = list(stages or [])
def append(self, stage: Optional[PipelineStage]):
"""
Append an additional stage to this pipeline.
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
assemble the stage from loose arguments.
"""
if stage is not None:
self.stages.append(stage)
def run(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: List[Image.Image],
callback: Optional[ProgressCallback],
**kwargs
) -> List[Image.Image]:
return self(
worker, server, params, sources=sources, callback=callback, **kwargs
)
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self
def steps(self, params: ImageParams, size: Size):
steps = 0
for callback, _params, kwargs in self.stages:
steps += callback.steps(kwargs.get("params", params), size)
return steps
def outputs(self, params: ImageParams, sources: int):
outputs = sources
for callback, _params, kwargs in self.stages:
outputs = callback.outputs(kwargs.get("params", params), outputs)
return outputs
def __call__(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: List[Image.Image],
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs
) -> List[Image.Image]:
"""
DEPRECATED: use `run` instead
"""
if callback is None:
callback = worker.get_progress_callback()
else:
callback = ChainProgress.from_progress(callback)
start = monotonic()
if len(sources) > 0:
logger.info(
"running pipeline on %s source images",
len(sources),
)
else:
logger.info("running pipeline without source images")
stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
logger.debug(
"running stage %s with %s source images, parameters: %s",
name,
len(stage_sources) - stage_sources.count(None),
kwargs.keys(),
)
per_stage_params = params
if "params" in kwargs:
per_stage_params = kwargs["params"]
kwargs.pop("params")
# the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = any(
[
needs_tile(
stage_pipe.max_tile,
stage_params.tile_size,
size=kwargs.get("size", None),
source=source,
)
for source in stage_sources
]
)
tile = stage_params.tile_size
if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
if stage_sources or must_tile:
stage_outputs = []
for source in stage_sources:
logger.info(
"image contains sources or is larger than tile size of %s, tiling stage",
tile,
)
extra_tiles = []
def stage_tile(
source_tile: Image.Image,
tile_mask: Image.Image,
dims: Tuple[int, int, int],
) -> Image.Image:
for _i in range(worker.retries):
try:
output_tile = stage_pipe.run(
worker,
server,
stage_params,
per_stage_params,
[source_tile],
tile_mask=tile_mask,
callback=callback,
dims=dims,
**kwargs,
)
if len(output_tile) > 1:
while len(extra_tiles) < len(output_tile):
extra_tiles.append([])
for tile, layer in zip(output_tile, extra_tiles):
layer.append((tile, dims))
if is_debug():
save_image(server, "last-tile.png", output_tile[0])
return output_tile[0]
except Exception:
worker.retries = worker.retries - 1
logger.exception(
"error while running stage pipeline for tile, %s retries left",
worker.retries,
)
server.cache.clear()
run_gc([worker.get_device()])
raise RetryException("exhausted retries on tile")
output = process_tile_order(
stage_params.tile_order,
source,
tile,
stage_params.outscale,
[stage_tile],
**kwargs,
)
stage_outputs.append(output)
if len(extra_tiles) > 1:
for layer in extra_tiles:
layer_output = Image.new("RGB", output.size)
for layer_tile, dims in layer:
layer_output.paste(layer_tile, (dims[0], dims[1]))
stage_outputs.append(layer_output)
stage_sources = stage_outputs
else:
logger.debug(
"image does not contain sources and is within tile size of %s, running stage",
tile,
)
for i in range(worker.retries):
try:
stage_outputs = stage_pipe.run(
worker,
server,
stage_params,
per_stage_params,
stage_sources,
callback=callback,
dims=(0, 0, tile),
**kwargs,
)
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
# does not like, so it throws
stage_sources = stage_outputs
break
except Exception:
worker.retries = worker.retries - 1
logger.exception(
"error while running stage pipeline, %s retries left",
worker.retries,
)
server.cache.clear()
run_gc([worker.get_device()])
if worker.retries <= 0:
raise RetryException("exhausted retries on stage")
logger.debug(
"finished stage %s with %s results",
name,
len(stage_sources),
)
if is_debug():
save_image(server, "last-stage.png", stage_sources[0])
end = monotonic()
duration = timedelta(seconds=(end - start))
logger.info(
"finished pipeline in %s with %s results",
duration,
len(stage_sources),
)
return stage_sources

View File

@ -6,7 +6,7 @@ from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -6,7 +6,7 @@ from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -0,0 +1,31 @@
from PIL.Image import Image, fromarray
from typing import List, Optional
import numpy as np
class StageResult:
"""
Chain pipeline stage result.
Can contain PIL images or numpy arrays, with helpers to convert between them.
"""
arrays: Optional[List[np.ndarray]]
images: Optional[List[Image]]
def __init__(self, arrays = None, images = None) -> None:
if arrays is not None and images is not None:
raise ValueError("stages must only return one type of result")
self.arrays = arrays
self.images = images
def as_numpy(self) -> List[np.ndarray]:
if self.arrays is not None:
return self.arrays
return [np.array(i) for i in self.images]
def as_image(self) -> List[Image]:
if self.images is not None:
return self.images
return [fromarray(i) for i in self.arrays]

View File

@ -6,7 +6,7 @@ from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -8,7 +8,7 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -18,7 +18,7 @@ from ..diffusers.utils import (
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -8,7 +8,7 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -1,38 +0,0 @@
from typing import List, Optional
from PIL import Image
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server.context import ServerContext
from ..worker.context import WorkerContext
class BaseStage:
max_tile = SizeChart.auto
def run(
self,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
_sources: List[Image.Image],
*args,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
raise NotImplementedError() # noqa
def steps(
self,
_params: ImageParams,
_size: Size,
) -> int:
return 1 # noqa
def outputs(
self,
_params: ImageParams,
sources: int,
) -> int:
return sources

View File

@ -0,0 +1,64 @@
from logging import getLogger
from .base import BaseStage
from .blend_denoise import BlendDenoiseStage
from .blend_img2img import BlendImg2ImgStage
from .blend_grid import BlendGridStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_simple import UpscaleSimpleStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
logger = getLogger(__name__)
CHAIN_STAGES = {
"blend-denoise": BlendDenoiseStage,
"blend-img2img": BlendImg2ImgStage,
"blend-inpaint": UpscaleOutpaintStage,
"blend-grid": BlendGridStage,
"blend-linear": BlendLinearStage,
"blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,
"reduce-thumbnail": ReduceThumbnailStage,
"source-noise": SourceNoiseStage,
"source-s3": SourceS3Stage,
"source-txt2img": SourceTxt2ImgStage,
"source-url": SourceURLStage,
"upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-simple": UpscaleSimpleStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}
def add_stage(name: str, stage: BaseStage) -> bool:
global CHAIN_STAGES
if name in CHAIN_STAGES:
logger.warning("cannot replace stage: %s", name)
return False
else:
CHAIN_STAGES[name] = stage
return True

View File

@ -10,7 +10,7 @@ from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -8,7 +8,7 @@ from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -18,7 +18,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -10,7 +10,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -6,7 +6,7 @@ from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -10,7 +10,7 @@ from ..diffusers.utils import encode_prompt, parse_prompt
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -10,7 +10,7 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
logger = getLogger(__name__)

View File

@ -86,7 +86,7 @@ class WorkerContext:
return 0
def get_progress_callback(self) -> ProgressCallback:
from ..chain.base import ChainProgress
from ..chain.pipeline import ChainProgress
def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step

View File

@ -1,6 +1,6 @@
import unittest
from onnx_web.chain.base import ChainProgress
from onnx_web.chain.pipeline import ChainProgress
class ChainProgressTests(unittest.TestCase):