1
0
Fork 0

feat(api): make chain stages into classes with max tile size and step count estimate

This commit is contained in:
Sean Sube 2023-07-01 07:10:53 -05:00
parent 5e1b70091c
commit 2913cd0382
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
29 changed files with 1121 additions and 1019 deletions

View File

@ -17,7 +17,7 @@ from .diffusers.run import (
run_upscale_pipeline,
)
from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import stage_upscale_correction
from .chain.upscale import stage_upscale_correction
from .image.utils import (
expand_image,
)

View File

@ -1,44 +1,44 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_img2img import blend_img2img
from .blend_inpaint import blend_inpaint
from .blend_linear import blend_linear
from .blend_mask import blend_mask
from .correct_codeformer import correct_codeformer
from .correct_gfpgan import correct_gfpgan
from .persist_disk import persist_disk
from .persist_s3 import persist_s3
from .reduce_crop import reduce_crop
from .reduce_thumbnail import reduce_thumbnail
from .source_noise import source_noise
from .source_s3 import source_s3
from .source_txt2img import source_txt2img
from .source_url import source_url
from .upscale_bsrgan import upscale_bsrgan
from .upscale_highres import upscale_highres
from .upscale_outpaint import upscale_outpaint
from .upscale_resrgan import upscale_resrgan
from .upscale_stable_diffusion import upscale_stable_diffusion
from .upscale_swinir import upscale_swinir
from .blend_img2img import BlendImg2ImgStage
from .blend_inpaint import BlendInpaintStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
CHAIN_STAGES = {
"blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint,
"blend-linear": blend_linear,
"blend-mask": blend_mask,
"correct-codeformer": correct_codeformer,
"correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk,
"persist-s3": persist_s3,
"reduce-crop": reduce_crop,
"reduce-thumbnail": reduce_thumbnail,
"source-noise": source_noise,
"source-s3": source_s3,
"source-txt2img": source_txt2img,
"source-url": source_url,
"upscale-bsrgan": upscale_bsrgan,
"upscale-highres": upscale_highres,
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,
"upscale-swinir": upscale_swinir,
"blend-img2img": BlendImg2ImgStage,
"blend-inpaint": BlendInpaintStage,
"blend-linear": BlendLinearStage,
"blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,
"reduce-thumbnail": ReduceThumbnailStage,
"source-noise": SourceNoiseStage,
"source-s3": SourceS3Stage,
"source-txt2img": SourceTxt2ImgStage,
"source-url": SourceURLStage,
"upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}

View File

@ -10,6 +10,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .utils import process_tile_order
logger = getLogger(__name__)
@ -35,7 +36,7 @@ class StageCallback(Protocol):
pass
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
@ -131,7 +132,7 @@ class ChainPipeline:
logger.info("running pipeline without source image")
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__name__
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
@ -158,7 +159,7 @@ class ChainPipeline:
)
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(
tile = stage_pipe.run(
job,
server,
stage_params,
@ -182,7 +183,7 @@ class ChainPipeline:
)
else:
logger.debug("image within tile size, running stage")
image = stage_pipe(
image = stage_pipe.run(
job,
server,
stage_params,

View File

@ -14,7 +14,9 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_img2img(
class BlendImg2ImgStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
@ -69,7 +71,9 @@ def blend_img2img(
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)

View File

@ -18,7 +18,9 @@ from .utils import process_tile_order
logger = getLogger(__name__)
def blend_inpaint(
class BlendInpaintStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
@ -115,7 +117,12 @@ def blend_inpaint(
return result.images[0]
output = process_tile_order(
stage.tile_order, source, SizeChart.auto, 1, [outpaint], overlap=params.overlap
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=params.overlap,
)
logger.info("final output image size: %s", output.size)

View File

@ -10,7 +10,9 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_linear(
class BlendLinearStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,

View File

@ -12,7 +12,9 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_mask(
class BlendMaskStage:
def run(
self,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,

View File

@ -9,10 +9,10 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
device = "cpu"
def correct_codeformer(
class CorrectCodeformerStage:
def run(
self,
job: WorkerContext,
_server: ServerContext,
_stage: StageParams,

View File

@ -13,7 +13,9 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def load_gfpgan(
class CorrectGFPGANStage:
def load(
self,
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
@ -48,8 +50,8 @@ def load_gfpgan(
return gfpgan
def correct_gfpgan(
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
@ -69,7 +71,7 @@ def correct_gfpgan(
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device()
gfpgan = load_gfpgan(server, stage, upscale, device)
gfpgan = self.load(server, stage, upscale, device)
output = np.array(source)
_, _, output = gfpgan.enhance(

View File

@ -10,7 +10,9 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def persist_disk(
class PersistDiskStage:
def run(
self,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,

View File

@ -12,7 +12,9 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def persist_s3(
class PersistS3Stage:
def run(
self,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,

View File

@ -10,7 +10,9 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def reduce_crop(
class ReduceCropStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
@ -25,5 +27,7 @@ def reduce_crop(
source = stage_source or source
image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height
)
return image

View File

@ -9,7 +9,9 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def reduce_thumbnail(
class ReduceThumbnailStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
@ -25,5 +27,7 @@ def reduce_thumbnail(
image = image.thumbnail((size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height
)
return image

View File

@ -10,7 +10,9 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def source_noise(
class SourceNoiseStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
@ -26,7 +28,9 @@ def source_noise(
logger.info("generating image from noise source")
if source is not None:
logger.warn("a source image was passed to a noise stage, but will be discarded")
logger.warn(
"a source image was passed to a noise stage, but will be discarded"
)
output = noise_source(source, (size.width, size.height), (0, 0))

View File

@ -12,7 +12,9 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def source_s3(
class SourceS3Stage:
def run(
self,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,

View File

@ -14,7 +14,9 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def source_txt2img(
class SourceTxt2ImgStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
@ -65,7 +67,9 @@ def source_txt2img(
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)

View File

@ -11,7 +11,9 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def source_url(
class SourceURLStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,

View File

@ -0,0 +1,31 @@
from typing import Optional
from PIL import Image
from onnx_web.params import ImageParams, Size, SizeChart, StageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
class BaseStage:
max_tile = SizeChart.auto
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*args,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
raise NotImplementedError()
def steps(
self,
_params: ImageParams,
size: Size,
) -> int:
raise NotImplementedError()

View File

@ -1,14 +1,14 @@
from logging import getLogger
from typing import List, Optional, Tuple
from ..chain import ChainPipeline, PipelineStage
from ..chain.correct_codeformer import correct_codeformer
from ..chain.correct_gfpgan import correct_gfpgan
from ..chain.upscale_bsrgan import upscale_bsrgan
from ..chain.upscale_resrgan import upscale_resrgan
from ..chain.upscale_stable_diffusion import upscale_stable_diffusion
from ..chain.upscale_swinir import upscale_swinir
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
from . import ChainPipeline, PipelineStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
logger = getLogger(__name__)
@ -72,23 +72,23 @@ def stage_upscale_correction(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (upscale_bsrgan, bsrgan_params, upscale_opts)
upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
elif "esrgan" in upscale.upscale_model:
esrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (upscale_resrgan, esrgan_params, upscale_opts)
upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (upscale_stable_diffusion, sd_params, upscale_opts)
upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
elif "swinir" in upscale.upscale_model:
swinir_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (upscale_swinir, swinir_params, upscale_opts)
upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
@ -98,9 +98,9 @@ def stage_upscale_correction(
tile_size=stage.tile_size, outscale=upscale.face_outscale
)
if "codeformer" in upscale.correction_model:
correct_stage = (correct_codeformer, face_params, upscale_opts)
correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
elif "gfpgan" in upscale.correction_model:
correct_stage = (correct_gfpgan, face_params, upscale_opts)
correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
else:
logger.warn("unknown correction model: %s", upscale.correction_model)

View File

@ -6,7 +6,7 @@ import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
@ -14,7 +14,11 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def load_bsrgan(
class UpscaleBSRGANStage:
max_tile = 64
def load(
self,
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
@ -43,8 +47,8 @@ def load_bsrgan(
return pipe
def upscale_bsrgan(
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
@ -64,7 +68,7 @@ def upscale_bsrgan(
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
device = job.get_device()
bsrgan = load_bsrgan(server, stage, upscale, device)
bsrgan = self.load(server, stage, upscale, device)
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
@ -77,7 +81,12 @@ def upscale_bsrgan(
scale = upscale.outscale
dest = np.zeros(
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
)
)
logger.trace("BSRGAN output shape: %s", dest.shape)
@ -116,3 +125,10 @@ def upscale_bsrgan(
output = Image.fromarray(dest, "RGB")
logger.debug("output image size: %s x %s", output.width, output.height)
return output
def steps(
self,
_params: ImageParams,
size: Size,
) -> int:
return size.width // self.max_tile * size.height // self.max_tile

View File

@ -3,18 +3,19 @@ from typing import Any, Optional
from PIL import Image
from ..chain.base import ChainPipeline
from ..chain.blend_img2img import blend_img2img
from ..diffusers.upscale import stage_upscale_correction
from ..chain import BlendImg2ImgStage, ChainPipeline
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
from .upscale import stage_upscale_correction
logger = getLogger(__name__)
def upscale_highres(
class UpscaleHighresStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
@ -57,7 +58,7 @@ def upscale_highres(
)
chain.stage(
blend_img2img,
BlendImg2ImgStage(),
StageParams(),
overlap=params.overlap,
strength=highres.strength,

View File

@ -18,7 +18,9 @@ from .utils import complete_tile, process_tile_grid, process_tile_order
logger = getLogger(__name__)
def upscale_outpaint(
class UpscaleOutpaintStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,

View File

@ -16,8 +16,9 @@ logger = getLogger(__name__)
TAG_X4_V3 = "real-esrgan-x4-v3"
def load_resrgan(
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
class UpscaleRealESRGANStage:
def load(
self, server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
# must be within load function for patches to take effect
# TODO: rewrite and remove
@ -69,8 +70,8 @@ def load_resrgan(
return upsampler
def upscale_resrgan(
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
@ -85,7 +86,7 @@ def upscale_resrgan(
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
output = np.array(source)
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
upsampler = self.load(server, upscale, job.get_device(), tile=stage.tile_size)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)

View File

@ -14,7 +14,9 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def upscale_stable_diffusion(
class UpscaleStableDiffusionStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,

View File

@ -14,7 +14,11 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def load_swinir(
class UpscaleSwinIRStage:
max_tile = 64
def load(
self,
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
@ -43,8 +47,8 @@ def load_swinir(
return pipe
def upscale_swinir(
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
@ -64,7 +68,7 @@ def upscale_swinir(
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
device = job.get_device()
swinir = load_swinir(server, stage, upscale, device)
swinir = self.load(server, stage, upscale, device)
# TODO: add support for other sizes
tile_size = (64, 64)
@ -79,7 +83,12 @@ def upscale_swinir(
scale = upscale.outscale
dest = np.zeros(
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
)
)
logger.info("SwinIR output shape: %s", dest.shape)

View File

@ -4,13 +4,14 @@ from typing import Any, List, Optional
from PIL import Image
from ..chain import (
blend_img2img,
blend_mask,
source_txt2img,
upscale_highres,
upscale_outpaint,
BlendImg2ImgStage,
BlendMaskStage,
ChainPipeline,
SourceTxt2ImgStage,
UpscaleHighresStage,
UpscaleOutpaintStage,
)
from ..chain.base import ChainPipeline
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..output import save_image
from ..params import (
Border,
@ -24,7 +25,6 @@ from ..server import ServerContext
from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast
from ..worker import WorkerContext
from .upscale import split_upscale, stage_upscale_correction
from .utils import parse_prompt
logger = getLogger(__name__)
@ -43,7 +43,7 @@ def run_txt2img_pipeline(
chain = ChainPipeline()
stage = StageParams()
chain.stage(
source_txt2img,
SourceTxt2ImgStage(),
stage,
size=size,
)
@ -61,7 +61,7 @@ def run_txt2img_pipeline(
# apply highres
for _i in range(highres.iterations):
chain.stage(
upscale_highres,
UpscaleHighresStage(),
StageParams(
outscale=highres.scale,
),
@ -125,7 +125,7 @@ def run_img2img_pipeline(
chain = ChainPipeline()
stage = StageParams()
chain.stage(
blend_img2img,
BlendImg2ImgStage(),
stage,
strength=strength,
)
@ -144,7 +144,7 @@ def run_img2img_pipeline(
if params.loopback > 0:
for _i in range(params.loopback):
chain.stage(
blend_img2img,
BlendImg2ImgStage(),
stage,
strength=strength,
)
@ -153,7 +153,7 @@ def run_img2img_pipeline(
if highres.iterations > 0:
for _i in range(highres.iterations):
chain.stage(
upscale_highres,
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
@ -223,7 +223,7 @@ def run_inpaint_pipeline(
chain = ChainPipeline()
stage = StageParams(tile_order=tile_order)
chain.stage(
upscale_outpaint,
UpscaleOutpaintStage(),
stage,
border=border,
stage_mask=mask,
@ -234,7 +234,7 @@ def run_inpaint_pipeline(
# apply highres
chain.stage(
upscale_highres,
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
@ -300,7 +300,7 @@ def run_upscale_pipeline(
# apply highres
chain.stage(
upscale_highres,
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
@ -353,7 +353,7 @@ def run_blend_pipeline(
# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams()
chain.stage(blend_mask, stage, stage_source=sources[1], stage_mask=mask)
chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
# apply upscaling and correction
stage_upscale_correction(

View File

@ -1,6 +1,4 @@
from typing import Tuple, Union
from PIL import Image, ImageChops, ImageOps
from PIL import Image, ImageChops
from ..params import Border, Size
from .mask_filter import mask_filter_none

View File

@ -356,12 +356,12 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
callback = CHAIN_STAGES[stage_data.get("type")]
stage_class = CHAIN_STAGES[stage_data.get("type")]
kwargs = stage_data.get("params", {})
logger.info("request stage: %s, %s", callback.__name__, kwargs)
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
stage = StageParams(
stage_data.get("name", callback.__name__),
stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tile_size")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)
@ -399,7 +399,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
kwargs["stage_mask"] = mask
pipeline.append((callback, stage, kwargs))
pipeline.append((stage_class(), stage, kwargs))
logger.info("running chain pipeline with %s stages", len(pipeline.stages))