feat(api): make chain stages into classes with max tile size and step count estimate
This commit is contained in:
parent
5e1b70091c
commit
2913cd0382
|
@ -17,7 +17,7 @@ from .diffusers.run import (
|
|||
run_upscale_pipeline,
|
||||
)
|
||||
from .diffusers.stub_scheduler import StubScheduler
|
||||
from .diffusers.upscale import stage_upscale_correction
|
||||
from .chain.upscale import stage_upscale_correction
|
||||
from .image.utils import (
|
||||
expand_image,
|
||||
)
|
||||
|
|
|
@ -1,44 +1,44 @@
|
|||
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
||||
from .blend_img2img import blend_img2img
|
||||
from .blend_inpaint import blend_inpaint
|
||||
from .blend_linear import blend_linear
|
||||
from .blend_mask import blend_mask
|
||||
from .correct_codeformer import correct_codeformer
|
||||
from .correct_gfpgan import correct_gfpgan
|
||||
from .persist_disk import persist_disk
|
||||
from .persist_s3 import persist_s3
|
||||
from .reduce_crop import reduce_crop
|
||||
from .reduce_thumbnail import reduce_thumbnail
|
||||
from .source_noise import source_noise
|
||||
from .source_s3 import source_s3
|
||||
from .source_txt2img import source_txt2img
|
||||
from .source_url import source_url
|
||||
from .upscale_bsrgan import upscale_bsrgan
|
||||
from .upscale_highres import upscale_highres
|
||||
from .upscale_outpaint import upscale_outpaint
|
||||
from .upscale_resrgan import upscale_resrgan
|
||||
from .upscale_stable_diffusion import upscale_stable_diffusion
|
||||
from .upscale_swinir import upscale_swinir
|
||||
from .blend_img2img import BlendImg2ImgStage
|
||||
from .blend_inpaint import BlendInpaintStage
|
||||
from .blend_linear import BlendLinearStage
|
||||
from .blend_mask import BlendMaskStage
|
||||
from .correct_codeformer import CorrectCodeformerStage
|
||||
from .correct_gfpgan import CorrectGFPGANStage
|
||||
from .persist_disk import PersistDiskStage
|
||||
from .persist_s3 import PersistS3Stage
|
||||
from .reduce_crop import ReduceCropStage
|
||||
from .reduce_thumbnail import ReduceThumbnailStage
|
||||
from .source_noise import SourceNoiseStage
|
||||
from .source_s3 import SourceS3Stage
|
||||
from .source_txt2img import SourceTxt2ImgStage
|
||||
from .source_url import SourceURLStage
|
||||
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||
from .upscale_highres import UpscaleHighresStage
|
||||
from .upscale_outpaint import UpscaleOutpaintStage
|
||||
from .upscale_resrgan import UpscaleRealESRGANStage
|
||||
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
||||
from .upscale_swinir import UpscaleSwinIRStage
|
||||
|
||||
CHAIN_STAGES = {
|
||||
"blend-img2img": blend_img2img,
|
||||
"blend-inpaint": blend_inpaint,
|
||||
"blend-linear": blend_linear,
|
||||
"blend-mask": blend_mask,
|
||||
"correct-codeformer": correct_codeformer,
|
||||
"correct-gfpgan": correct_gfpgan,
|
||||
"persist-disk": persist_disk,
|
||||
"persist-s3": persist_s3,
|
||||
"reduce-crop": reduce_crop,
|
||||
"reduce-thumbnail": reduce_thumbnail,
|
||||
"source-noise": source_noise,
|
||||
"source-s3": source_s3,
|
||||
"source-txt2img": source_txt2img,
|
||||
"source-url": source_url,
|
||||
"upscale-bsrgan": upscale_bsrgan,
|
||||
"upscale-highres": upscale_highres,
|
||||
"upscale-outpaint": upscale_outpaint,
|
||||
"upscale-resrgan": upscale_resrgan,
|
||||
"upscale-stable-diffusion": upscale_stable_diffusion,
|
||||
"upscale-swinir": upscale_swinir,
|
||||
"blend-img2img": BlendImg2ImgStage,
|
||||
"blend-inpaint": BlendInpaintStage,
|
||||
"blend-linear": BlendLinearStage,
|
||||
"blend-mask": BlendMaskStage,
|
||||
"correct-codeformer": CorrectCodeformerStage,
|
||||
"correct-gfpgan": CorrectGFPGANStage,
|
||||
"persist-disk": PersistDiskStage,
|
||||
"persist-s3": PersistS3Stage,
|
||||
"reduce-crop": ReduceCropStage,
|
||||
"reduce-thumbnail": ReduceThumbnailStage,
|
||||
"source-noise": SourceNoiseStage,
|
||||
"source-s3": SourceS3Stage,
|
||||
"source-txt2img": SourceTxt2ImgStage,
|
||||
"source-url": SourceURLStage,
|
||||
"upscale-bsrgan": UpscaleBSRGANStage,
|
||||
"upscale-highres": UpscaleHighresStage,
|
||||
"upscale-outpaint": UpscaleOutpaintStage,
|
||||
"upscale-resrgan": UpscaleRealESRGANStage,
|
||||
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
|
||||
"upscale-swinir": UpscaleSwinIRStage,
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ from ..params import ImageParams, StageParams
|
|||
from ..server import ServerContext
|
||||
from ..utils import is_debug
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .stage import BaseStage
|
||||
from .utils import process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -35,7 +36,7 @@ class StageCallback(Protocol):
|
|||
pass
|
||||
|
||||
|
||||
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
|
||||
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
||||
|
||||
|
||||
class ChainProgress:
|
||||
|
@ -131,7 +132,7 @@ class ChainPipeline:
|
|||
logger.info("running pipeline without source image")
|
||||
|
||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||
name = stage_params.name or stage_pipe.__name__
|
||||
name = stage_params.name or stage_pipe.__class__.__name__
|
||||
kwargs = stage_kwargs or {}
|
||||
kwargs = {**pipeline_kwargs, **kwargs}
|
||||
|
||||
|
@ -158,7 +159,7 @@ class ChainPipeline:
|
|||
)
|
||||
|
||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||
tile = stage_pipe(
|
||||
tile = stage_pipe.run(
|
||||
job,
|
||||
server,
|
||||
stage_params,
|
||||
|
@ -182,7 +183,7 @@ class ChainPipeline:
|
|||
)
|
||||
else:
|
||||
logger.debug("image within tile size, running stage")
|
||||
image = stage_pipe(
|
||||
image = stage_pipe.run(
|
||||
job,
|
||||
server,
|
||||
stage_params,
|
||||
|
|
|
@ -14,7 +14,9 @@ from ..worker import ProgressCallback, WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def blend_img2img(
|
||||
class BlendImg2ImgStage:
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
@ -69,7 +71,9 @@ def blend_img2img(
|
|||
)
|
||||
else:
|
||||
# encode and record alternative prompts outside of LPW
|
||||
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
|
||||
prompt_embeds = encode_prompt(
|
||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||
)
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
|
|
@ -18,7 +18,9 @@ from .utils import process_tile_order
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def blend_inpaint(
|
||||
class BlendInpaintStage:
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
|
@ -115,7 +117,12 @@ def blend_inpaint(
|
|||
return result.images[0]
|
||||
|
||||
output = process_tile_order(
|
||||
stage.tile_order, source, SizeChart.auto, 1, [outpaint], overlap=params.overlap
|
||||
stage.tile_order,
|
||||
source,
|
||||
SizeChart.auto,
|
||||
1,
|
||||
[outpaint],
|
||||
overlap=params.overlap,
|
||||
)
|
||||
|
||||
logger.info("final output image size: %s", output.size)
|
||||
|
|
|
@ -10,7 +10,9 @@ from ..worker import ProgressCallback, WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def blend_linear(
|
||||
class BlendLinearStage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
|
|
@ -12,7 +12,9 @@ from ..worker import ProgressCallback, WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def blend_mask(
|
||||
class BlendMaskStage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
|
|
@ -9,10 +9,10 @@ from ..worker import WorkerContext
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
device = "cpu"
|
||||
|
||||
|
||||
def correct_codeformer(
|
||||
class CorrectCodeformerStage:
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
|
|
@ -13,7 +13,9 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def load_gfpgan(
|
||||
class CorrectGFPGANStage:
|
||||
def load(
|
||||
self,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
upscale: UpscaleParams,
|
||||
|
@ -48,8 +50,8 @@ def load_gfpgan(
|
|||
|
||||
return gfpgan
|
||||
|
||||
|
||||
def correct_gfpgan(
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
|
@ -69,7 +71,7 @@ def correct_gfpgan(
|
|||
|
||||
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
||||
device = job.get_device()
|
||||
gfpgan = load_gfpgan(server, stage, upscale, device)
|
||||
gfpgan = self.load(server, stage, upscale, device)
|
||||
|
||||
output = np.array(source)
|
||||
_, _, output = gfpgan.enhance(
|
||||
|
|
|
@ -10,7 +10,9 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def persist_disk(
|
||||
class PersistDiskStage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
|
|
@ -12,7 +12,9 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def persist_s3(
|
||||
class PersistS3Stage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
|
|
@ -10,7 +10,9 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def reduce_crop(
|
||||
class ReduceCropStage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
@ -25,5 +27,7 @@ def reduce_crop(
|
|||
source = stage_source or source
|
||||
|
||||
image = source.crop((origin.width, origin.height, size.width, size.height))
|
||||
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
|
||||
logger.info(
|
||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||
)
|
||||
return image
|
||||
|
|
|
@ -9,7 +9,9 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def reduce_thumbnail(
|
||||
class ReduceThumbnailStage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
@ -25,5 +27,7 @@ def reduce_thumbnail(
|
|||
|
||||
image = image.thumbnail((size.width, size.height))
|
||||
|
||||
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
|
||||
logger.info(
|
||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||
)
|
||||
return image
|
||||
|
|
|
@ -10,7 +10,9 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def source_noise(
|
||||
class SourceNoiseStage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
@ -26,7 +28,9 @@ def source_noise(
|
|||
logger.info("generating image from noise source")
|
||||
|
||||
if source is not None:
|
||||
logger.warn("a source image was passed to a noise stage, but will be discarded")
|
||||
logger.warn(
|
||||
"a source image was passed to a noise stage, but will be discarded"
|
||||
)
|
||||
|
||||
output = noise_source(source, (size.width, size.height), (0, 0))
|
||||
|
||||
|
|
|
@ -12,7 +12,9 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def source_s3(
|
||||
class SourceS3Stage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
|
|
@ -14,7 +14,9 @@ from ..worker import ProgressCallback, WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def source_txt2img(
|
||||
class SourceTxt2ImgStage:
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
@ -65,7 +67,9 @@ def source_txt2img(
|
|||
)
|
||||
else:
|
||||
# encode and record alternative prompts outside of LPW
|
||||
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
|
||||
prompt_embeds = encode_prompt(
|
||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||
)
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
|
|
@ -11,7 +11,9 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def source_url(
|
||||
class SourceURLStage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.params import ImageParams, Size, SizeChart, StageParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
|
||||
|
||||
class BaseStage:
|
||||
max_tile = SizeChart.auto
|
||||
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
*args,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
raise NotImplementedError()
|
||||
|
||||
def steps(
|
||||
self,
|
||||
_params: ImageParams,
|
||||
size: Size,
|
||||
) -> int:
|
||||
raise NotImplementedError()
|
|
@ -1,14 +1,14 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..chain import ChainPipeline, PipelineStage
|
||||
from ..chain.correct_codeformer import correct_codeformer
|
||||
from ..chain.correct_gfpgan import correct_gfpgan
|
||||
from ..chain.upscale_bsrgan import upscale_bsrgan
|
||||
from ..chain.upscale_resrgan import upscale_resrgan
|
||||
from ..chain.upscale_stable_diffusion import upscale_stable_diffusion
|
||||
from ..chain.upscale_swinir import upscale_swinir
|
||||
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from . import ChainPipeline, PipelineStage
|
||||
from .correct_codeformer import CorrectCodeformerStage
|
||||
from .correct_gfpgan import CorrectGFPGANStage
|
||||
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||
from .upscale_resrgan import UpscaleRealESRGANStage
|
||||
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
||||
from .upscale_swinir import UpscaleSwinIRStage
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -72,23 +72,23 @@ def stage_upscale_correction(
|
|||
tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale,
|
||||
)
|
||||
upscale_stage = (upscale_bsrgan, bsrgan_params, upscale_opts)
|
||||
upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
|
||||
elif "esrgan" in upscale.upscale_model:
|
||||
esrgan_params = StageParams(
|
||||
tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale,
|
||||
)
|
||||
upscale_stage = (upscale_resrgan, esrgan_params, upscale_opts)
|
||||
upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
|
||||
elif "stable-diffusion" in upscale.upscale_model:
|
||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||
upscale_stage = (upscale_stable_diffusion, sd_params, upscale_opts)
|
||||
upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
|
||||
elif "swinir" in upscale.upscale_model:
|
||||
swinir_params = StageParams(
|
||||
tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale,
|
||||
)
|
||||
upscale_stage = (upscale_swinir, swinir_params, upscale_opts)
|
||||
upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
|
||||
else:
|
||||
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
||||
|
||||
|
@ -98,9 +98,9 @@ def stage_upscale_correction(
|
|||
tile_size=stage.tile_size, outscale=upscale.face_outscale
|
||||
)
|
||||
if "codeformer" in upscale.correction_model:
|
||||
correct_stage = (correct_codeformer, face_params, upscale_opts)
|
||||
correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
|
||||
elif "gfpgan" in upscale.correction_model:
|
||||
correct_stage = (correct_gfpgan, face_params, upscale_opts)
|
||||
correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
|
||||
else:
|
||||
logger.warn("unknown correction model: %s", upscale.correction_model)
|
||||
|
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
from PIL import Image
|
||||
|
||||
from ..models.onnx import OnnxModel
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
|
@ -14,7 +14,11 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def load_bsrgan(
|
||||
class UpscaleBSRGANStage:
|
||||
max_tile = 64
|
||||
|
||||
def load(
|
||||
self,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
upscale: UpscaleParams,
|
||||
|
@ -43,8 +47,8 @@ def load_bsrgan(
|
|||
|
||||
return pipe
|
||||
|
||||
|
||||
def upscale_bsrgan(
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
|
@ -64,7 +68,7 @@ def upscale_bsrgan(
|
|||
|
||||
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
|
||||
device = job.get_device()
|
||||
bsrgan = load_bsrgan(server, stage, upscale, device)
|
||||
bsrgan = self.load(server, stage, upscale, device)
|
||||
|
||||
tile_size = (64, 64)
|
||||
tile_x = source.width // tile_size[0]
|
||||
|
@ -77,7 +81,12 @@ def upscale_bsrgan(
|
|||
|
||||
scale = upscale.outscale
|
||||
dest = np.zeros(
|
||||
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
|
||||
(
|
||||
image.shape[0],
|
||||
image.shape[1],
|
||||
image.shape[2] * scale,
|
||||
image.shape[3] * scale,
|
||||
)
|
||||
)
|
||||
logger.trace("BSRGAN output shape: %s", dest.shape)
|
||||
|
||||
|
@ -116,3 +125,10 @@ def upscale_bsrgan(
|
|||
output = Image.fromarray(dest, "RGB")
|
||||
logger.debug("output image size: %s x %s", output.width, output.height)
|
||||
return output
|
||||
|
||||
def steps(
|
||||
self,
|
||||
_params: ImageParams,
|
||||
size: Size,
|
||||
) -> int:
|
||||
return size.width // self.max_tile * size.height // self.max_tile
|
||||
|
|
|
@ -3,18 +3,19 @@ from typing import Any, Optional
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..chain.base import ChainPipeline
|
||||
from ..chain.blend_img2img import blend_img2img
|
||||
from ..diffusers.upscale import stage_upscale_correction
|
||||
from ..chain import BlendImg2ImgStage, ChainPipeline
|
||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker.context import ProgressCallback
|
||||
from .upscale import stage_upscale_correction
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def upscale_highres(
|
||||
class UpscaleHighresStage:
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
@ -57,7 +58,7 @@ def upscale_highres(
|
|||
)
|
||||
|
||||
chain.stage(
|
||||
blend_img2img,
|
||||
BlendImg2ImgStage(),
|
||||
StageParams(),
|
||||
overlap=params.overlap,
|
||||
strength=highres.strength,
|
||||
|
|
|
@ -18,7 +18,9 @@ from .utils import complete_tile, process_tile_grid, process_tile_order
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def upscale_outpaint(
|
||||
class UpscaleOutpaintStage:
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
|
|
|
@ -16,8 +16,9 @@ logger = getLogger(__name__)
|
|||
TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||
|
||||
|
||||
def load_resrgan(
|
||||
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
||||
class UpscaleRealESRGANStage:
|
||||
def load(
|
||||
self, server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
||||
):
|
||||
# must be within load function for patches to take effect
|
||||
# TODO: rewrite and remove
|
||||
|
@ -69,8 +70,8 @@ def load_resrgan(
|
|||
|
||||
return upsampler
|
||||
|
||||
|
||||
def upscale_resrgan(
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
|
@ -85,7 +86,7 @@ def upscale_resrgan(
|
|||
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
||||
|
||||
output = np.array(source)
|
||||
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
|
||||
upsampler = self.load(server, upscale, job.get_device(), tile=stage.tile_size)
|
||||
|
||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||
|
||||
|
|
|
@ -14,7 +14,9 @@ from ..worker import ProgressCallback, WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def upscale_stable_diffusion(
|
||||
class UpscaleStableDiffusionStage:
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
|
|
|
@ -14,7 +14,11 @@ from ..worker import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def load_swinir(
|
||||
class UpscaleSwinIRStage:
|
||||
max_tile = 64
|
||||
|
||||
def load(
|
||||
self,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
upscale: UpscaleParams,
|
||||
|
@ -43,8 +47,8 @@ def load_swinir(
|
|||
|
||||
return pipe
|
||||
|
||||
|
||||
def upscale_swinir(
|
||||
def run(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
|
@ -64,7 +68,7 @@ def upscale_swinir(
|
|||
|
||||
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
|
||||
device = job.get_device()
|
||||
swinir = load_swinir(server, stage, upscale, device)
|
||||
swinir = self.load(server, stage, upscale, device)
|
||||
|
||||
# TODO: add support for other sizes
|
||||
tile_size = (64, 64)
|
||||
|
@ -79,7 +83,12 @@ def upscale_swinir(
|
|||
|
||||
scale = upscale.outscale
|
||||
dest = np.zeros(
|
||||
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
|
||||
(
|
||||
image.shape[0],
|
||||
image.shape[1],
|
||||
image.shape[2] * scale,
|
||||
image.shape[3] * scale,
|
||||
)
|
||||
)
|
||||
logger.info("SwinIR output shape: %s", dest.shape)
|
||||
|
||||
|
|
|
@ -4,13 +4,14 @@ from typing import Any, List, Optional
|
|||
from PIL import Image
|
||||
|
||||
from ..chain import (
|
||||
blend_img2img,
|
||||
blend_mask,
|
||||
source_txt2img,
|
||||
upscale_highres,
|
||||
upscale_outpaint,
|
||||
BlendImg2ImgStage,
|
||||
BlendMaskStage,
|
||||
ChainPipeline,
|
||||
SourceTxt2ImgStage,
|
||||
UpscaleHighresStage,
|
||||
UpscaleOutpaintStage,
|
||||
)
|
||||
from ..chain.base import ChainPipeline
|
||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||
from ..output import save_image
|
||||
from ..params import (
|
||||
Border,
|
||||
|
@ -24,7 +25,6 @@ from ..server import ServerContext
|
|||
from ..server.load import get_source_filters
|
||||
from ..utils import run_gc, show_system_toast
|
||||
from ..worker import WorkerContext
|
||||
from .upscale import split_upscale, stage_upscale_correction
|
||||
from .utils import parse_prompt
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -43,7 +43,7 @@ def run_txt2img_pipeline(
|
|||
chain = ChainPipeline()
|
||||
stage = StageParams()
|
||||
chain.stage(
|
||||
source_txt2img,
|
||||
SourceTxt2ImgStage(),
|
||||
stage,
|
||||
size=size,
|
||||
)
|
||||
|
@ -61,7 +61,7 @@ def run_txt2img_pipeline(
|
|||
# apply highres
|
||||
for _i in range(highres.iterations):
|
||||
chain.stage(
|
||||
upscale_highres,
|
||||
UpscaleHighresStage(),
|
||||
StageParams(
|
||||
outscale=highres.scale,
|
||||
),
|
||||
|
@ -125,7 +125,7 @@ def run_img2img_pipeline(
|
|||
chain = ChainPipeline()
|
||||
stage = StageParams()
|
||||
chain.stage(
|
||||
blend_img2img,
|
||||
BlendImg2ImgStage(),
|
||||
stage,
|
||||
strength=strength,
|
||||
)
|
||||
|
@ -144,7 +144,7 @@ def run_img2img_pipeline(
|
|||
if params.loopback > 0:
|
||||
for _i in range(params.loopback):
|
||||
chain.stage(
|
||||
blend_img2img,
|
||||
BlendImg2ImgStage(),
|
||||
stage,
|
||||
strength=strength,
|
||||
)
|
||||
|
@ -153,7 +153,7 @@ def run_img2img_pipeline(
|
|||
if highres.iterations > 0:
|
||||
for _i in range(highres.iterations):
|
||||
chain.stage(
|
||||
upscale_highres,
|
||||
UpscaleHighresStage(),
|
||||
stage,
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
|
@ -223,7 +223,7 @@ def run_inpaint_pipeline(
|
|||
chain = ChainPipeline()
|
||||
stage = StageParams(tile_order=tile_order)
|
||||
chain.stage(
|
||||
upscale_outpaint,
|
||||
UpscaleOutpaintStage(),
|
||||
stage,
|
||||
border=border,
|
||||
stage_mask=mask,
|
||||
|
@ -234,7 +234,7 @@ def run_inpaint_pipeline(
|
|||
|
||||
# apply highres
|
||||
chain.stage(
|
||||
upscale_highres,
|
||||
UpscaleHighresStage(),
|
||||
stage,
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
|
@ -300,7 +300,7 @@ def run_upscale_pipeline(
|
|||
|
||||
# apply highres
|
||||
chain.stage(
|
||||
upscale_highres,
|
||||
UpscaleHighresStage(),
|
||||
stage,
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
|
@ -353,7 +353,7 @@ def run_blend_pipeline(
|
|||
# set up the chain pipeline and base stage
|
||||
chain = ChainPipeline()
|
||||
stage = StageParams()
|
||||
chain.stage(blend_mask, stage, stage_source=sources[1], stage_mask=mask)
|
||||
chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
|
||||
|
||||
# apply upscaling and correction
|
||||
stage_upscale_correction(
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
from typing import Tuple, Union
|
||||
|
||||
from PIL import Image, ImageChops, ImageOps
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from ..params import Border, Size
|
||||
from .mask_filter import mask_filter_none
|
||||
|
|
|
@ -356,12 +356,12 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
pipeline = ChainPipeline()
|
||||
for stage_data in data.get("stages", []):
|
||||
callback = CHAIN_STAGES[stage_data.get("type")]
|
||||
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
||||
kwargs = stage_data.get("params", {})
|
||||
logger.info("request stage: %s, %s", callback.__name__, kwargs)
|
||||
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
|
||||
|
||||
stage = StageParams(
|
||||
stage_data.get("name", callback.__name__),
|
||||
stage_data.get("name", stage_class.__name__),
|
||||
tile_size=get_size(kwargs.get("tile_size")),
|
||||
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
||||
)
|
||||
|
@ -399,7 +399,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||
kwargs["stage_mask"] = mask
|
||||
|
||||
pipeline.append((callback, stage, kwargs))
|
||||
pipeline.append((stage_class(), stage, kwargs))
|
||||
|
||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||
|
||||
|
|
Loading…
Reference in New Issue