
feat(api): make chain stages into classes with max tile size and step count estimate

Sean Sube 2023-07-01 07:10:53 -05:00
parent 5e1b70091c
commit 2913cd0382
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
29 changed files with 1121 additions and 1019 deletions

View File: onnx_web/__init__.py

@@ -17,7 +17,7 @@ from .diffusers.run import (
     run_upscale_pipeline,
 )
 from .diffusers.stub_scheduler import StubScheduler
-from .diffusers.upscale import stage_upscale_correction
+from .chain.upscale import stage_upscale_correction
 from .image.utils import (
     expand_image,
 )

View File: onnx_web/chain/__init__.py

@ -1,44 +1,44 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_img2img import blend_img2img from .blend_img2img import BlendImg2ImgStage
from .blend_inpaint import blend_inpaint from .blend_inpaint import BlendInpaintStage
from .blend_linear import blend_linear from .blend_linear import BlendLinearStage
from .blend_mask import blend_mask from .blend_mask import BlendMaskStage
from .correct_codeformer import correct_codeformer from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import correct_gfpgan from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import persist_disk from .persist_disk import PersistDiskStage
from .persist_s3 import persist_s3 from .persist_s3 import PersistS3Stage
from .reduce_crop import reduce_crop from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import reduce_thumbnail from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import source_noise from .source_noise import SourceNoiseStage
from .source_s3 import source_s3 from .source_s3 import SourceS3Stage
from .source_txt2img import source_txt2img from .source_txt2img import SourceTxt2ImgStage
from .source_url import source_url from .source_url import SourceURLStage
from .upscale_bsrgan import upscale_bsrgan from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import upscale_highres from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import upscale_outpaint from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import upscale_resrgan from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_stable_diffusion import upscale_stable_diffusion from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import upscale_swinir from .upscale_swinir import UpscaleSwinIRStage
CHAIN_STAGES = { CHAIN_STAGES = {
"blend-img2img": blend_img2img, "blend-img2img": BlendImg2ImgStage,
"blend-inpaint": blend_inpaint, "blend-inpaint": BlendInpaintStage,
"blend-linear": blend_linear, "blend-linear": BlendLinearStage,
"blend-mask": blend_mask, "blend-mask": BlendMaskStage,
"correct-codeformer": correct_codeformer, "correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": correct_gfpgan, "correct-gfpgan": CorrectGFPGANStage,
"persist-disk": persist_disk, "persist-disk": PersistDiskStage,
"persist-s3": persist_s3, "persist-s3": PersistS3Stage,
"reduce-crop": reduce_crop, "reduce-crop": ReduceCropStage,
"reduce-thumbnail": reduce_thumbnail, "reduce-thumbnail": ReduceThumbnailStage,
"source-noise": source_noise, "source-noise": SourceNoiseStage,
"source-s3": source_s3, "source-s3": SourceS3Stage,
"source-txt2img": source_txt2img, "source-txt2img": SourceTxt2ImgStage,
"source-url": source_url, "source-url": SourceURLStage,
"upscale-bsrgan": upscale_bsrgan, "upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": upscale_highres, "upscale-highres": UpscaleHighresStage,
"upscale-outpaint": upscale_outpaint, "upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": upscale_resrgan, "upscale-resrgan": UpscaleRealESRGANStage,
"upscale-stable-diffusion": upscale_stable_diffusion, "upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": upscale_swinir, "upscale-swinir": UpscaleSwinIRStage,
} }
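The registry now maps each stage name to a class rather than a callable, so a caller has to instantiate the class and invoke its run() method. A minimal sketch of resolving a stage by name, assuming only the CHAIN_STAGES mapping above and the run() signatures shown in the stage files below; the context objects are passed in by the worker and are not constructed here:

from onnx_web.chain import CHAIN_STAGES


def run_named_stage(name, job, server, stage_params, image_params, source, **stage_kwargs):
    # look up the class for a stage name, e.g. "blend-img2img" -> BlendImg2ImgStage
    stage_class = CHAIN_STAGES[name]
    # stages are plain classes now; instantiate one per invocation
    stage = stage_class()
    return stage.run(job, server, stage_params, image_params, source, **stage_kwargs)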

View File: onnx_web/chain/base.py

@@ -10,6 +10,7 @@ from ..params import ImageParams, StageParams
 from ..server import ServerContext
 from ..utils import is_debug
 from ..worker import ProgressCallback, WorkerContext
+from .stage import BaseStage
 from .utils import process_tile_order
 
 logger = getLogger(__name__)

@@ -35,7 +36,7 @@ class StageCallback(Protocol):
         pass
 
 
-PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
+PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
 
 
 class ChainProgress:

@@ -131,7 +132,7 @@ class ChainPipeline:
             logger.info("running pipeline without source image")
 
         for stage_pipe, stage_params, stage_kwargs in self.stages:
-            name = stage_params.name or stage_pipe.__name__
+            name = stage_params.name or stage_pipe.__class__.__name__
             kwargs = stage_kwargs or {}
             kwargs = {**pipeline_kwargs, **kwargs}

@@ -158,7 +159,7 @@ class ChainPipeline:
                 )
 
                 def stage_tile(tile: Image.Image, _dims) -> Image.Image:
-                    tile = stage_pipe(
+                    tile = stage_pipe.run(
                         job,
                         server,
                         stage_params,

@@ -182,7 +183,7 @@ class ChainPipeline:
                 )
             else:
                 logger.debug("image within tile size, running stage")
-                image = stage_pipe(
+                image = stage_pipe.run(
                     job,
                     server,
                     stage_params,
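Because PipelineStage is now a tuple of (BaseStage, StageParams, kwargs), a pipeline is assembled from stage instances and the chain calls each instance's run() method. A rough sketch of that usage, based on the chain.stage() and chain(...) call pattern visible in the upscale_highres stage later in this commit; the job, server, and params arguments are the worker and server context objects and are assumed rather than constructed here:

from onnx_web.chain import BlendImg2ImgStage, ChainPipeline, SourceTxt2ImgStage
from onnx_web.params import Size, StageParams


def build_example_chain(job, server, params, progress=None):
    chain = ChainPipeline()
    # each stage() call appends a (stage instance, StageParams, kwargs) tuple
    chain.stage(SourceTxt2ImgStage(), StageParams(), size=Size(512, 512))
    chain.stage(BlendImg2ImgStage(), StageParams(), strength=0.5)
    # ChainPipeline.__call__ iterates the tuples and invokes each stage's run()
    return chain(job, server, params, None, callback=progress)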

View File: onnx_web/chain/blend_img2img.py

@@ -14,77 +14,81 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)


class BlendImg2ImgStage:
    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        *,
        strength: float,
        callback: Optional[ProgressCallback] = None,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        params = params.with_args(**kwargs)
        source = stage_source or source
        logger.info(
            "blending image using img2img, %s steps: %s", params.steps, params.prompt
        )

        prompt_pairs, loras, inversions = parse_prompt(params)

        pipe_type = params.get_valid_pipeline("img2img")
        pipe = load_pipeline(
            server,
            params,
            pipe_type,
            job.get_device(),
            inversions=inversions,
            loras=loras,
        )

        pipe_params = {}
        if pipe_type == "controlnet":
            pipe_params["controlnet_conditioning_scale"] = strength
        elif pipe_type == "img2img":
            pipe_params["strength"] = strength
        elif pipe_type == "panorama":
            pipe_params["strength"] = strength
        elif pipe_type == "pix2pix":
            pipe_params["image_guidance_scale"] = strength

        if params.lpw():
            logger.debug("using LPW pipeline for img2img")
            rng = torch.manual_seed(params.seed)
            result = pipe.img2img(
                params.prompt,
                generator=rng,
                guidance_scale=params.cfg,
                image=source,
                negative_prompt=params.negative_prompt,
                num_inference_steps=params.steps,
                callback=callback,
                **pipe_params,
            )
        else:
            # encode and record alternative prompts outside of LPW
            prompt_embeds = encode_prompt(
                pipe, prompt_pairs, params.batch, params.do_cfg()
            )
            pipe.unet.set_prompts(prompt_embeds)

            rng = np.random.RandomState(params.seed)
            result = pipe(
                params.prompt,
                generator=rng,
                guidance_scale=params.cfg,
                image=source,
                negative_prompt=params.negative_prompt,
                num_inference_steps=params.steps,
                callback=callback,
                **pipe_params,
            )

        output = result.images[0]

        logger.info("final output image size: %sx%s", output.width, output.height)
        return output

View File: onnx_web/chain/blend_inpaint.py

@@ -18,105 +18,112 @@ from .utils import process_tile_order
logger = getLogger(__name__)


class BlendInpaintStage:
    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        *,
        expand: Border,
        stage_source: Optional[Image.Image] = None,
        stage_mask: Optional[Image.Image] = None,
        fill_color: str = "white",
        mask_filter: Callable = mask_filter_none,
        noise_source: Callable = noise_source_histogram,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> Image.Image:
        params = params.with_args(**kwargs)
        expand = expand.with_args(**kwargs)
        source = source or stage_source
        logger.info(
            "blending image using inpaint, %s steps: %s", params.steps, params.prompt
        )

        if stage_mask is None:
            # if no mask was provided, keep the full source image
            stage_mask = Image.new("RGB", source.size, "black")

        source, stage_mask, noise, _full_dims = expand_image(
            source,
            stage_mask,
            expand,
            fill=fill_color,
            noise_source=noise_source,
            mask_filter=mask_filter,
        )

        if is_debug():
            save_image(server, "last-source.png", source)
            save_image(server, "last-mask.png", stage_mask)
            save_image(server, "last-noise.png", noise)

        pipe_type = "lpw" if params.lpw() else "inpaint"
        pipe = load_pipeline(
            server,
            params,
            pipe_type,
            job.get_device(),
            # TODO: add LoRAs and TIs
        )

        def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
            left, top, tile = dims
            size = Size(*tile_source.size)
            tile_mask = stage_mask.crop((left, top, left + tile, top + tile))

            if is_debug():
                save_image(server, "tile-source.png", tile_source)
                save_image(server, "tile-mask.png", tile_mask)

            latents = get_latents_from_seed(params.seed, size)
            if params.lpw():
                logger.debug("using LPW pipeline for inpaint")
                rng = torch.manual_seed(params.seed)
                result = pipe.inpaint(
                    params.prompt,
                    generator=rng,
                    guidance_scale=params.cfg,
                    height=size.height,
                    image=tile_source,
                    latents=latents,
                    mask_image=tile_mask,
                    negative_prompt=params.negative_prompt,
                    num_inference_steps=params.steps,
                    width=size.width,
                    eta=params.eta,
                    callback=callback,
                )
            else:
                rng = np.random.RandomState(params.seed)
                result = pipe(
                    params.prompt,
                    generator=rng,
                    guidance_scale=params.cfg,
                    height=size.height,
                    image=tile_source,
                    latents=latents,
                    mask_image=stage_mask,
                    negative_prompt=params.negative_prompt,
                    num_inference_steps=params.steps,
                    width=size.width,
                    eta=params.eta,
                    callback=callback,
                )

            return result.images[0]

        output = process_tile_order(
            stage.tile_order,
            source,
            SizeChart.auto,
            1,
            [outpaint],
            overlap=params.overlap,
        )

        logger.info("final output image size: %s", output.size)
        return output

View File: onnx_web/chain/blend_linear.py

@@ -10,17 +10,19 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)


class BlendLinearStage:
    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        *,
        alpha: float,
        sources: Optional[List[Image.Image]] = None,
        _callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> Image.Image:
        logger.info("blending image using linear interpolation")

        return Image.blend(sources[1], sources[0], alpha)

View File: onnx_web/chain/blend_mask.py

@@ -12,26 +12,28 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)


class BlendMaskStage:
    def run(
        self,
        _job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        stage_source: Optional[Image.Image] = None,
        stage_mask: Optional[Image.Image] = None,
        _callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> Image.Image:
        logger.info("blending image using mask")

        mult_mask = Image.new("RGBA", stage_mask.size, color="black")
        mult_mask.alpha_composite(stage_mask)
        mult_mask = mult_mask.convert("L")

        if is_debug():
            save_image(server, "last-mask.png", stage_mask)
            save_image(server, "last-mult-mask.png", mult_mask)

        return Image.composite(stage_source, source, mult_mask)

View File: onnx_web/chain/correct_codeformer.py

@@ -9,28 +9,28 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class CorrectCodeformerStage:
    def run(
        self,
        job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        stage_source: Optional[Image.Image] = None,
        upscale: UpscaleParams,
        **kwargs,
    ) -> Image.Image:
        # must be within the load function for patch to take effect
        # TODO: rewrite and remove
        from codeformer import CodeFormer

        source = stage_source or source
        upscale = upscale.with_args(**kwargs)

        device = job.get_device()
        pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
        return pipe(source)

View File: onnx_web/chain/correct_gfpgan.py

@@ -13,72 +13,74 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class CorrectGFPGANStage:
    def load(
        self,
        server: ServerContext,
        _stage: StageParams,
        upscale: UpscaleParams,
        device: DeviceParams,
    ):
        # must be within the load function for patch to take effect
        # TODO: rewrite and remove
        from gfpgan import GFPGANer

        face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
        cache_key = (face_path,)
        cache_pipe = server.cache.get("gfpgan", cache_key)
        if cache_pipe is not None:
            logger.info("reusing existing GFPGAN pipeline")
            return cache_pipe

        logger.debug("loading GFPGAN model from %s", face_path)

        # TODO: find a way to pass the ONNX model to underlying architectures
        gfpgan = GFPGANer(
            arch="clean",
            bg_upsampler=None,
            channel_multiplier=2,
            device=device.torch_str(),
            model_path=face_path,
            upscale=upscale.face_outscale,
        )

        server.cache.set("gfpgan", cache_key, gfpgan)
        run_gc([device])

        return gfpgan

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        upscale = upscale.with_args(**kwargs)
        source = stage_source or source

        if upscale.correction_model is None:
            logger.warn("no face model given, skipping")
            return source

        logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
        device = job.get_device()
        gfpgan = self.load(server, stage, upscale, device)

        output = np.array(source)
        _, _, output = gfpgan.enhance(
            output,
            has_aligned=False,
            only_center_face=False,
            paste_back=True,
            weight=upscale.face_strength,
        )
        output = Image.fromarray(output, "RGB")

        return output

View File: onnx_web/chain/persist_disk.py

@@ -10,19 +10,21 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class PersistDiskStage:
    def run(
        self,
        _job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        *,
        output: str,
        stage_source: Image.Image,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source

        dest = save_image(server, output, source, params=params)
        logger.info("saved image to %s", dest)
        return source

View File: onnx_web/chain/persist_s3.py

@@ -12,33 +12,35 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class PersistS3Stage:
    def run(
        self,
        _job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        output: str,
        bucket: str,
        endpoint_url: Optional[str] = None,
        profile_name: Optional[str] = None,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source

        session = Session(profile_name=profile_name)
        s3 = session.client("s3", endpoint_url=endpoint_url)

        data = BytesIO()
        source.save(data, format=server.image_format)
        data.seek(0)

        try:
            s3.upload_fileobj(data, bucket, output)
            logger.info("saved image to s3://%s/%s", bucket, output)
        except Exception:
            logger.exception("error saving image to S3")

        return source

View File: onnx_web/chain/reduce_crop.py

@@ -10,20 +10,24 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class ReduceCropStage:
    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        origin: Size,
        size: Size,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source

        image = source.crop((origin.width, origin.height, size.width, size.height))
        logger.info(
            "created thumbnail with dimensions: %sx%s", image.width, image.height
        )
        return image

View File: onnx_web/chain/reduce_thumbnail.py

@@ -9,21 +9,25 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class ReduceThumbnailStage:
    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        size: Size,
        stage_source: Image.Image,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source
        image = source.copy()

        image = image.thumbnail((size.width, size.height))

        logger.info(
            "created thumbnail with dimensions: %sx%s", image.width, image.height
        )
        return image

View File: onnx_web/chain/source_noise.py

@@ -10,25 +10,29 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class SourceNoiseStage:
    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        size: Size,
        noise_source: Callable,
        stage_source: Image.Image,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source
        logger.info("generating image from noise source")

        if source is not None:
            logger.warn(
                "a source image was passed to a noise stage, but will be discarded"
            )

        output = noise_source(source, (size.width, size.height), (0, 0))

        logger.info("final output image size: %sx%s", output.width, output.height)
        return output

View File: onnx_web/chain/source_s3.py

@@ -12,31 +12,33 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class SourceS3Stage:
    def run(
        self,
        _job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        source_key: str,
        bucket: str,
        endpoint_url: Optional[str] = None,
        profile_name: Optional[str] = None,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source

        session = Session(profile_name=profile_name)
        s3 = session.client("s3", endpoint_url=endpoint_url)

        try:
            logger.info("loading image from s3://%s/%s", bucket, source_key)
            data = BytesIO()
            s3.download_fileobj(bucket, source_key, data)

            data.seek(0)
            return Image.open(data)
        except Exception:
            logger.exception("error loading image from S3")

View File: onnx_web/chain/source_txt2img.py

@@ -14,74 +14,78 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)


class SourceTxt2ImgStage:
    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        params: ImageParams,
        _source: Image.Image,
        *,
        size: Size,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> Image.Image:
        params = params.with_args(**kwargs)
        size = size.with_args(**kwargs)
        logger.info(
            "generating image using txt2img, %s steps: %s", params.steps, params.prompt
        )

        if "stage_source" in kwargs:
            logger.warn(
                "a source image was passed to a txt2img stage, and will be discarded"
            )

        prompt_pairs, loras, inversions = parse_prompt(params)

        latents = get_latents_from_seed(params.seed, size)
        pipe_type = params.get_valid_pipeline("txt2img")
        pipe = load_pipeline(
            server,
            params,
            pipe_type,
            job.get_device(),
            inversions=inversions,
            loras=loras,
        )

        if params.lpw():
            logger.debug("using LPW pipeline for txt2img")
            rng = torch.manual_seed(params.seed)
            result = pipe.text2img(
                params.prompt,
                height=size.height,
                width=size.width,
                generator=rng,
                guidance_scale=params.cfg,
                latents=latents,
                negative_prompt=params.negative_prompt,
                num_inference_steps=params.steps,
                callback=callback,
            )
        else:
            # encode and record alternative prompts outside of LPW
            prompt_embeds = encode_prompt(
                pipe, prompt_pairs, params.batch, params.do_cfg()
            )
            pipe.unet.set_prompts(prompt_embeds)

            rng = np.random.RandomState(params.seed)
            result = pipe(
                params.prompt,
                height=size.height,
                width=size.width,
                generator=rng,
                guidance_scale=params.cfg,
                latents=latents,
                negative_prompt=params.negative_prompt,
                num_inference_steps=params.steps,
                callback=callback,
            )

        output = result.images[0]

        logger.info("final output image size: %sx%s", output.width, output.height)
        return output

View File: onnx_web/chain/source_url.py

@@ -11,27 +11,29 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class SourceURLStage:
    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        source_url: str,
        stage_source: Image.Image,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source
        logger.info("loading image from URL source")

        if source is not None:
            logger.warn(
                "a source image was passed to a source stage, and will be discarded"
            )

        response = requests.get(source_url)
        output = Image.open(BytesIO(response.content))

        logger.info("final output image size: %sx%s", output.width, output.height)
        return output

View File: onnx_web/chain/stage.py

@@ -0,0 +1,31 @@
from typing import Optional

from PIL import Image

from onnx_web.params import ImageParams, Size, SizeChart, StageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext


class BaseStage:
    max_tile = SizeChart.auto

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *args,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        raise NotImplementedError()

    def steps(
        self,
        _params: ImageParams,
        size: Size,
    ) -> int:
        raise NotImplementedError()
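New stages are expected to subclass BaseStage, override max_tile when the model has a fixed tile size, and implement both run() and the steps() estimate. A hypothetical minimal stage following that contract; the class below is illustrative only and is not part of this commit:

from typing import Optional

from PIL import Image

from onnx_web.chain.stage import BaseStage
from onnx_web.params import ImageParams, Size, SizeChart, StageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext


class InvertColorsStage(BaseStage):
    # no model involved, so any tile size is acceptable
    max_tile = SizeChart.auto

    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *args,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source
        # invert each RGB channel of the incoming image or tile
        return Image.eval(source.convert("RGB"), lambda px: 255 - px)

    def steps(self, _params: ImageParams, _size: Size) -> int:
        # a single pass over the image, no diffusion steps involved
        return 1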

View File: onnx_web/chain/upscale.py

@@ -1,14 +1,14 @@
 from logging import getLogger
 from typing import List, Optional, Tuple
 
-from ..chain import ChainPipeline, PipelineStage
-from ..chain.correct_codeformer import correct_codeformer
-from ..chain.correct_gfpgan import correct_gfpgan
-from ..chain.upscale_bsrgan import upscale_bsrgan
-from ..chain.upscale_resrgan import upscale_resrgan
-from ..chain.upscale_stable_diffusion import upscale_stable_diffusion
-from ..chain.upscale_swinir import upscale_swinir
 from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
+from . import ChainPipeline, PipelineStage
+from .correct_codeformer import CorrectCodeformerStage
+from .correct_gfpgan import CorrectGFPGANStage
+from .upscale_bsrgan import UpscaleBSRGANStage
+from .upscale_resrgan import UpscaleRealESRGANStage
+from .upscale_stable_diffusion import UpscaleStableDiffusionStage
+from .upscale_swinir import UpscaleSwinIRStage
 
 logger = getLogger(__name__)

@@ -72,23 +72,23 @@ def stage_upscale_correction(
             tile_size=stage.tile_size,
             outscale=upscale.outscale,
         )
-        upscale_stage = (upscale_bsrgan, bsrgan_params, upscale_opts)
+        upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
     elif "esrgan" in upscale.upscale_model:
         esrgan_params = StageParams(
             tile_size=stage.tile_size,
             outscale=upscale.outscale,
         )
-        upscale_stage = (upscale_resrgan, esrgan_params, upscale_opts)
+        upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
     elif "stable-diffusion" in upscale.upscale_model:
         mini_tile = min(SizeChart.mini, stage.tile_size)
         sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
-        upscale_stage = (upscale_stable_diffusion, sd_params, upscale_opts)
+        upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
     elif "swinir" in upscale.upscale_model:
         swinir_params = StageParams(
             tile_size=stage.tile_size,
             outscale=upscale.outscale,
         )
-        upscale_stage = (upscale_swinir, swinir_params, upscale_opts)
+        upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
     else:
         logger.warn("unknown upscaling model: %s", upscale.upscale_model)

@@ -98,9 +98,9 @@ def stage_upscale_correction(
             tile_size=stage.tile_size, outscale=upscale.face_outscale
         )
         if "codeformer" in upscale.correction_model:
-            correct_stage = (correct_codeformer, face_params, upscale_opts)
+            correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
         elif "gfpgan" in upscale.correction_model:
-            correct_stage = (correct_gfpgan, face_params, upscale_opts)
+            correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
         else:
             logger.warn("unknown correction model: %s", upscale.correction_model)
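For callers, stage_upscale_correction keeps the shape it had under the old diffusers module; only the import path and the stage tuples it appends change. A sketch of appending the upscale and correction stages to an existing chain, following the call used by the highres stage later in this commit; the params and upscale arguments are assumed to be prepared ImageParams and UpscaleParams objects:

from onnx_web.chain import ChainPipeline
from onnx_web.chain.upscale import stage_upscale_correction
from onnx_web.params import StageParams


def append_upscale_correction(chain: ChainPipeline, params, upscale):
    # appends (UpscaleBSRGANStage()/UpscaleRealESRGANStage()/..., StageParams, opts)
    # tuples to the chain, chosen from upscale.upscale_model and correction_model
    stage_upscale_correction(
        StageParams(),
        params,
        upscale=upscale,
        chain=chain,
    )
    return chain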

View File: onnx_web/chain/upscale_bsrgan.py

@@ -6,7 +6,7 @@ import numpy as np
 from PIL import Image
 
 from ..models.onnx import OnnxModel
-from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
 from ..server import ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext

@@ -14,105 +14,121 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class UpscaleBSRGANStage:
    max_tile = 64

    def load(
        self,
        server: ServerContext,
        _stage: StageParams,
        upscale: UpscaleParams,
        device: DeviceParams,
    ):
        # must be within the load function for patch to take effect
        model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
        cache_key = (model_path,)
        cache_pipe = server.cache.get("bsrgan", cache_key)

        if cache_pipe is not None:
            logger.debug("reusing existing BSRGAN pipeline")
            return cache_pipe

        logger.info("loading BSRGAN model from %s", model_path)

        pipe = OnnxModel(
            server,
            model_path,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
        )

        server.cache.set("bsrgan", cache_key, pipe)
        run_gc([device])

        return pipe

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        upscale = upscale.with_args(**kwargs)
        source = stage_source or source

        if upscale.upscale_model is None:
            logger.warn("no upscaling model given, skipping")
            return source

        logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
        device = job.get_device()
        bsrgan = self.load(server, stage, upscale, device)

        tile_size = (64, 64)
        tile_x = source.width // tile_size[0]
        tile_y = source.height // tile_size[1]

        image = np.array(source) / 255.0
        image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
        image = np.expand_dims(image, axis=0)
        logger.trace("BSRGAN input shape: %s", image.shape)

        scale = upscale.outscale
        dest = np.zeros(
            (
                image.shape[0],
                image.shape[1],
                image.shape[2] * scale,
                image.shape[3] * scale,
            )
        )
        logger.trace("BSRGAN output shape: %s", dest.shape)

        for x in range(tile_x):
            for y in range(tile_y):
                xt = x * tile_size[0]
                yt = y * tile_size[1]

                ix1 = xt
                ix2 = xt + tile_size[0]
                iy1 = yt
                iy2 = yt + tile_size[1]
                logger.debug(
                    "running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
                    ix1,
                    ix2,
                    iy1,
                    iy2,
                    ix1 * scale,
                    ix2 * scale,
                    iy1 * scale,
                    iy2 * scale,
                )

                dest[
                    :,
                    :,
                    ix1 * scale : ix2 * scale,
                    iy1 * scale : iy2 * scale,
                ] = bsrgan(image[:, :, ix1:ix2, iy1:iy2])

        dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
        dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
        dest = (dest * 255.0).round().astype(np.uint8)

        output = Image.fromarray(dest, "RGB")
        logger.debug("output image size: %s x %s", output.width, output.height)
        return output

    def steps(
        self,
        _params: ImageParams,
        size: Size,
    ) -> int:
        return size.width // self.max_tile * size.height // self.max_tile
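The new steps() estimate counts model invocations rather than diffusion steps: with max_tile = 64, a 512x512 source splits into 8 tiles across and 8 tiles down, so the estimate is 64, matching the number of BSRGAN calls made by the tile loop in run(). A quick check of that arithmetic, assuming only that Size exposes width and height:

from onnx_web.params import Size

# mirror of UpscaleBSRGANStage.steps() for a 512x512 input with max_tile = 64
max_tile = 64
size = Size(512, 512)
estimate = size.width // max_tile * size.height // max_tile
assert estimate == 64  # 8 tiles across * 8 tiles down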

View File: onnx_web/chain/upscale_highres.py

@@ -3,70 +3,71 @@ from typing import Any, Optional

from PIL import Image

from ..chain import BlendImg2ImgStage, ChainPipeline
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
from .upscale import stage_upscale_correction

logger = getLogger(__name__)


class UpscaleHighresStage:
    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        *,
        highres: HighresParams,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        pipeline: Optional[Any] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source

        if highres.scale <= 1:
            return source

        chain = ChainPipeline()
        scaled_size = (source.width * highres.scale, source.height * highres.scale)

        # TODO: upscaling within the same stage prevents tiling from happening and causes OOM
        if highres.method == "bilinear":
            logger.debug("using bilinear interpolation for highres")
            source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
        elif highres.method == "lanczos":
            logger.debug("using Lanczos interpolation for highres")
            source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
        else:
            logger.debug("using upscaling pipeline for highres")
            stage_upscale_correction(
                StageParams(),
                params,
                upscale=upscale.with_args(
                    faces=False,
                    scale=highres.scale,
                    outscale=highres.scale,
                ),
                chain=chain,
            )

        chain.stage(
            BlendImg2ImgStage(),
            StageParams(),
            overlap=params.overlap,
            strength=highres.strength,
        )

        return chain(
            job,
            server,
            params,
            source,
            callback=callback,
        )

View File: onnx_web/chain/upscale_outpaint.py

@@ -18,138 +18,140 @@ from .utils import complete_tile, process_tile_grid, process_tile_order
logger = getLogger(__name__)


class UpscaleOutpaintStage:
    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        *,
        border: Border,
        stage_source: Optional[Image.Image] = None,
        stage_mask: Optional[Image.Image] = None,
        fill_color: str = "white",
        mask_filter: Callable = mask_filter_none,
        noise_source: Callable = noise_source_histogram,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source
        logger.info(
            "upscaling %s x %s image by expanding borders: %s",
            source.width,
            source.height,
            border,
        )

        margin_x = float(max(border.left, border.right))
        margin_y = float(max(border.top, border.bottom))
        overlap = min(margin_x / source.width, margin_y / source.height)

        if stage_mask is None:
            # if no mask was provided, keep the full source image
            stage_mask = Image.new("RGB", source.size, "black")

        # masks start as 512x512, resize to cover the source, then trim the extra
        mask_max = max(source.width, source.height)
        stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
        stage_mask = stage_mask.crop((0, 0, source.width, source.height))

        source, stage_mask, noise, full_size = expand_image(
            source,
            stage_mask,
            border,
            fill=fill_color,
            noise_source=noise_source,
            mask_filter=mask_filter,
        )

        full_latents = get_latents_from_seed(params.seed, Size(*full_size))

        draw_mask = ImageDraw.Draw(stage_mask)

        if is_debug():
            save_image(server, "last-source.png", source)
            save_image(server, "last-mask.png", stage_mask)
            save_image(server, "last-noise.png", noise)

        def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
            left, top, tile = dims
            size = Size(*tile_source.size)
            tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
            tile_mask = complete_tile(tile_mask, tile)

            if is_debug():
                save_image(server, "tile-source.png", tile_source)
                save_image(server, "tile-mask.png", tile_mask)

            latents = get_tile_latents(full_latents, dims, size)
            pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
            pipe = load_pipeline(
                server,
                params,
                pipe_type,
                job.get_device(),
                # TODO: load LoRAs and TIs
            )
            if params.lpw():
                logger.debug("using LPW pipeline for inpaint")
                rng = torch.manual_seed(params.seed)
                result = pipe.inpaint(
                    tile_source,
                    tile_mask,
                    params.prompt,
                    generator=rng,
                    guidance_scale=params.cfg,
                    height=size.height,
                    latents=latents,
                    negative_prompt=params.negative_prompt,
                    num_inference_steps=params.steps,
                    width=size.width,
                    callback=callback,
                )
            else:
                rng = np.random.RandomState(params.seed)
                result = pipe(
                    params.prompt,
                    tile_source,
                    tile_mask,
                    height=size.height,
                    width=size.width,
                    num_inference_steps=params.steps,
                    guidance_scale=params.cfg,
                    negative_prompt=params.negative_prompt,
                    generator=rng,
                    latents=latents,
                    callback=callback,
                )

            # once part of the image has been drawn, keep it
            draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
            return result.images[0]

        if params.pipeline == "panorama":
            logger.debug("outpainting with one shot panorama, no tiling")
            return outpaint(source, (0, 0, max(source.width, source.height)))
        if overlap == 0:
            logger.debug("outpainting with 0 margin, using grid tiling")
            output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
        elif border.left == border.right and border.top == border.bottom:
            logger.debug(
                "outpainting with an even border, using spiral tiling with %s overlap",
                overlap,
            )
            output = process_tile_order(
                stage.tile_order,
                source,
                SizeChart.auto,
                1,
                [outpaint],
                overlap=overlap,
            )
        else:
            logger.debug("outpainting with an uneven border, using grid tiling")
            output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])

        logger.info("final output image size: %sx%s", output.width, output.height)
        return output

View File: onnx_web/chain/upscale_resrgan.py

@@ -16,79 +16,80 @@ logger = getLogger(__name__)
TAG_X4_V3 = "real-esrgan-x4-v3"


class UpscaleRealESRGANStage:
    def load(
        self, server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
    ):
        # must be within load function for patches to take effect
        # TODO: rewrite and remove
        from realesrgan import RealESRGANer

        model_file = "%s.%s" % (params.upscale_model, params.format)
        model_path = path.join(server.model_path, model_file)

        cache_key = (model_path, params.format)
        cache_pipe = server.cache.get("resrgan", cache_key)
        if cache_pipe is not None:
            logger.info("reusing existing Real ESRGAN pipeline")
            return cache_pipe

        if not path.isfile(model_path):
            raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path)

        # TODO: swap for regular RRDBNet after rewriting wrapper
        model = OnnxRRDBNet(
            server,
            model_file,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
        )

        dni_weight = None
        if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
            wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
            model_path = [model_path, wdn_model_path]
            dni_weight = [params.denoise, 1 - params.denoise]

        logger.debug("loading Real ESRGAN upscale model from %s", model_path)

        # TODO: shouldn't need the PTH file
        model_path_pth = path.join(server.cache_path, ("%s.pth" % params.upscale_model))
        upsampler = RealESRGANer(
            scale=params.scale,
            model_path=model_path_pth,
            dni_weight=dni_weight,
            model=model,
            tile=tile,
            tile_pad=params.tile_pad,
            pre_pad=params.pre_pad,
            half=False,  # TODO: use server optimizations
        )

        server.cache.set("resrgan", cache_key, upsampler)
        run_gc([device])

        return upsampler

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        source = stage_source or source
        logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)

        output = np.array(source)
        upsampler = self.load(server, upscale, job.get_device(), tile=stage.tile_size)

        output, _ = upsampler.enhance(output, outscale=upscale.outscale)

        output = Image.fromarray(output, "RGB")
        logger.info("final output image size: %sx%s", output.width, output.height)
        return output
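A minimal usage sketch of the new class-based stage, assuming the package is importable as onnx_web and that UpscaleParams takes the upscale model name as its first argument; it follows the same chain.stage(...) pattern used in diffusers/run.py below:

# sketch only, not part of this commit; package path and UpscaleParams
# constructor are assumptions
from onnx_web.chain import ChainPipeline, StageParams, UpscaleRealESRGANStage
from onnx_web.params import UpscaleParams

chain = ChainPipeline()
stage = StageParams()
chain.stage(
    UpscaleRealESRGANStage(),  # stage instance instead of the old callback
    stage,
    upscale=UpscaleParams("real-esrgan-x4-v3"),  # assumed constructor signature
)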

View File

@@ -14,52 +14,54 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)


class UpscaleStableDiffusionStage:
    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> Image.Image:
        params = params.with_args(**kwargs)
        upscale = upscale.with_args(**kwargs)
        source = stage_source or source
        logger.info(
            "upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
        )

        prompt_pairs, _loras, _inversions = parse_prompt(params)

        pipeline = load_pipeline(
            server,
            params,
            "upscale",
            job.get_device(),
            model=path.join(server.model_path, upscale.upscale_model),
        )
        generator = torch.manual_seed(params.seed)

        prompt_embeds = encode_prompt(
            pipeline,
            prompt_pairs,
            num_images_per_prompt=params.batch,
            do_classifier_free_guidance=params.do_cfg(),
        )
        pipeline.unet.set_prompts(prompt_embeds)

        return pipeline(
            params.prompt,
            source,
            generator=generator,
            guidance_scale=params.cfg,
            negative_prompt=params.negative_prompt,
            num_inference_steps=params.steps,
            eta=params.eta,
            noise_level=upscale.denoise,
            callback=callback,
        ).images[0]
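The generator = torch.manual_seed(params.seed) line above is what makes repeated upscales reproducible for a given seed; a small self-contained illustration of that pattern, with 42 standing in for params.seed:

import torch

# seeding the default generator returns it, so it can be passed along
generator = torch.manual_seed(42)
noise = torch.randn((1, 4, 8, 8), generator=generator)
# re-running this snippet with the same seed reproduces the same tensor
print(noise.sum().item())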

View File

@@ -14,107 +14,116 @@ from ..worker import WorkerContext
logger = getLogger(__name__)


class UpscaleSwinIRStage:
    max_tile = 64

    def load(
        self,
        server: ServerContext,
        _stage: StageParams,
        upscale: UpscaleParams,
        device: DeviceParams,
    ):
        # must be within the load function for patch to take effect
        model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
        cache_key = (model_path,)
        cache_pipe = server.cache.get("swinir", cache_key)

        if cache_pipe is not None:
            logger.info("reusing existing SwinIR pipeline")
            return cache_pipe

        logger.debug("loading SwinIR model from %s", model_path)

        pipe = OnnxModel(
            server,
            model_path,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
        )

        server.cache.set("swinir", cache_key, pipe)
        run_gc([device])

        return pipe

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        upscale = upscale.with_args(**kwargs)
        source = stage_source or source

        if upscale.upscale_model is None:
            logger.warn("no correction model given, skipping")
            return source

        logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
        device = job.get_device()
        swinir = self.load(server, stage, upscale, device)

        # TODO: add support for other sizes
        tile_size = (64, 64)
        tile_x = source.width // tile_size[0]
        tile_y = source.height // tile_size[1]

        # TODO: add support for grayscale (1-channel) images
        image = np.array(source) / 255.0
        image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
        image = np.expand_dims(image, axis=0)
        logger.info("SwinIR input shape: %s", image.shape)

        scale = upscale.outscale
        dest = np.zeros(
            (
                image.shape[0],
                image.shape[1],
                image.shape[2] * scale,
                image.shape[3] * scale,
            )
        )
        logger.info("SwinIR output shape: %s", dest.shape)

        for x in range(tile_x):
            for y in range(tile_y):
                xt = x * tile_size[0]
                yt = y * tile_size[1]

                ix1 = xt
                ix2 = xt + tile_size[0]
                iy1 = yt
                iy2 = yt + tile_size[1]
                logger.info(
                    "running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
                    ix1,
                    ix2,
                    iy1,
                    iy2,
                    ix1 * scale,
                    ix2 * scale,
                    iy1 * scale,
                    iy2 * scale,
                )

                dest[
                    :,
                    :,
                    ix1 * scale : ix2 * scale,
                    iy1 * scale : iy2 * scale,
                ] = swinir(image[:, :, ix1:ix2, iy1:iy2])

        dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
        dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
        dest = (dest * 255.0).round().astype(np.uint8)

        output = Image.fromarray(dest, "RGB")
        logger.info("output image size: %s x %s", output.width, output.height)
        return output
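With the fixed 64 px tiles above (matching the new max_tile = 64 attribute), the stage walks floor(width / 64) by floor(height / 64) blocks; a short standalone sketch of that index math, assuming any remainder outside whole tiles is simply not covered by the loop:

# sketch of the SwinIR tile indexing, using a hypothetical 300x200 source
tile = 64
width, height = 300, 200
tile_x = width // tile    # 4 tiles across, 44 px of width left over
tile_y = height // tile   # 3 tiles down, 8 px of height left over
for x in range(tile_x):
    for y in range(tile_y):
        left, top = x * tile, y * tile
        # input tile bounds; output bounds are these values times outscale
        box = (left, top, left + tile, top + tile)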

View File

@@ -4,13 +4,14 @@ from typing import Any, List, Optional
 from PIL import Image

 from ..chain import (
-    blend_img2img,
-    blend_mask,
-    source_txt2img,
-    upscale_highres,
-    upscale_outpaint,
+    BlendImg2ImgStage,
+    BlendMaskStage,
+    ChainPipeline,
+    SourceTxt2ImgStage,
+    UpscaleHighresStage,
+    UpscaleOutpaintStage,
 )
-from ..chain.base import ChainPipeline
+from ..chain.upscale import split_upscale, stage_upscale_correction
 from ..output import save_image
 from ..params import (
     Border,
@@ -24,7 +25,6 @@ from ..server import ServerContext
 from ..server.load import get_source_filters
 from ..utils import run_gc, show_system_toast
 from ..worker import WorkerContext
-from .upscale import split_upscale, stage_upscale_correction
 from .utils import parse_prompt

 logger = getLogger(__name__)
@@ -43,7 +43,7 @@ def run_txt2img_pipeline(
     chain = ChainPipeline()
     stage = StageParams()
     chain.stage(
-        source_txt2img,
+        SourceTxt2ImgStage(),
         stage,
         size=size,
     )
@@ -61,7 +61,7 @@ def run_txt2img_pipeline(
     # apply highres
     for _i in range(highres.iterations):
         chain.stage(
-            upscale_highres,
+            UpscaleHighresStage(),
             StageParams(
                 outscale=highres.scale,
             ),
@@ -125,7 +125,7 @@ def run_img2img_pipeline(
     chain = ChainPipeline()
     stage = StageParams()
     chain.stage(
-        blend_img2img,
+        BlendImg2ImgStage(),
         stage,
         strength=strength,
     )
@@ -144,7 +144,7 @@ def run_img2img_pipeline(
     if params.loopback > 0:
         for _i in range(params.loopback):
             chain.stage(
-                blend_img2img,
+                BlendImg2ImgStage(),
                 stage,
                 strength=strength,
             )
@@ -153,7 +153,7 @@ def run_img2img_pipeline(
     if highres.iterations > 0:
         for _i in range(highres.iterations):
             chain.stage(
-                upscale_highres,
+                UpscaleHighresStage(),
                 stage,
                 highres=highres,
                 upscale=upscale,
@@ -223,7 +223,7 @@ def run_inpaint_pipeline(
     chain = ChainPipeline()
     stage = StageParams(tile_order=tile_order)
     chain.stage(
-        upscale_outpaint,
+        UpscaleOutpaintStage(),
         stage,
         border=border,
         stage_mask=mask,
@@ -234,7 +234,7 @@ def run_inpaint_pipeline(

     # apply highres
     chain.stage(
-        upscale_highres,
+        UpscaleHighresStage(),
         stage,
         highres=highres,
         upscale=upscale,
@@ -300,7 +300,7 @@ def run_upscale_pipeline(

     # apply highres
     chain.stage(
-        upscale_highres,
+        UpscaleHighresStage(),
         stage,
         highres=highres,
         upscale=upscale,
@@ -353,7 +353,7 @@ def run_blend_pipeline(
     # set up the chain pipeline and base stage
     chain = ChainPipeline()
     stage = StageParams()
-    chain.stage(blend_mask, stage, stage_source=sources[1], stage_mask=mask)
+    chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)

     # apply upscaling and correction
     stage_upscale_correction(

View File

@@ -1,6 +1,4 @@
-from typing import Tuple, Union
-
-from PIL import Image, ImageChops, ImageOps
+from PIL import Image, ImageChops

 from ..params import Border, Size
 from .mask_filter import mask_filter_none

View File

@@ -356,12 +356,12 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):

     pipeline = ChainPipeline()
     for stage_data in data.get("stages", []):
-        callback = CHAIN_STAGES[stage_data.get("type")]
+        stage_class = CHAIN_STAGES[stage_data.get("type")]
         kwargs = stage_data.get("params", {})
-        logger.info("request stage: %s, %s", callback.__name__, kwargs)
+        logger.info("request stage: %s, %s", stage_class.__name__, kwargs)

         stage = StageParams(
-            stage_data.get("name", callback.__name__),
+            stage_data.get("name", stage_class.__name__),
             tile_size=get_size(kwargs.get("tile_size")),
             outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
         )
@@ -399,7 +399,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
             mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
             kwargs["stage_mask"] = mask

-        pipeline.append((callback, stage, kwargs))
+        pipeline.append((stage_class(), stage, kwargs))

     logger.info("running chain pipeline with %s stages", len(pipeline.stages))