fix(api): remove nested tiling in highres
This commit is contained in:
parent
eef055eddd
commit
a7be651032
|
@ -17,6 +17,7 @@ from .upscale_bsrgan import UpscaleBSRGANStage
|
|||
from .upscale_highres import UpscaleHighresStage
|
||||
from .upscale_outpaint import UpscaleOutpaintStage
|
||||
from .upscale_resrgan import UpscaleRealESRGANStage
|
||||
from .upscale_simple import UpscaleSimpleStage
|
||||
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
||||
from .upscale_swinir import UpscaleSwinIRStage
|
||||
|
||||
|
@ -39,6 +40,7 @@ CHAIN_STAGES = {
|
|||
"upscale-highres": UpscaleHighresStage,
|
||||
"upscale-outpaint": UpscaleOutpaintStage,
|
||||
"upscale-resrgan": UpscaleRealESRGANStage,
|
||||
"upscale-simple": UpscaleSimpleStage,
|
||||
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
|
||||
"upscale-swinir": UpscaleSwinIRStage,
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ from ..server import ServerContext
|
|||
from ..utils import is_debug
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .stage import BaseStage
|
||||
from .utils import process_tile_order
|
||||
from .tile import process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
|||
from ..server import ServerContext
|
||||
from ..utils import is_debug
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .utils import process_tile_order
|
||||
from .tile import process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from ..chain.base import ChainPipeline
|
||||
from ..chain.blend_img2img import BlendImg2ImgStage
|
||||
from ..chain.upscale import stage_upscale_correction
|
||||
from ..chain.upscale_simple import UpscaleSimpleStage
|
||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def stage_highres(
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
highres: HighresParams,
|
||||
upscale: UpscaleParams,
|
||||
chain: Optional[ChainPipeline] = None,
|
||||
) -> ChainPipeline:
|
||||
logger.info("staging highres pipeline at %s", highres.scale)
|
||||
|
||||
if chain is None:
|
||||
chain = ChainPipeline()
|
||||
|
||||
if highres.iterations < 1:
|
||||
logger.debug("no highres iterations, skipping")
|
||||
return chain
|
||||
|
||||
if highres.method == "upscale":
|
||||
logger.debug("using upscaling pipeline for highres")
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
params,
|
||||
upscale=upscale.with_args(
|
||||
faces=False,
|
||||
scale=highres.scale,
|
||||
outscale=highres.scale,
|
||||
),
|
||||
chain=chain,
|
||||
overlap=params.overlap,
|
||||
)
|
||||
else:
|
||||
logger.debug("using simple upscaling for highres")
|
||||
chain.stage(
|
||||
UpscaleSimpleStage(),
|
||||
stage,
|
||||
overlap=params.overlap,
|
||||
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
|
||||
)
|
||||
|
||||
chain.stage(
|
||||
BlendImg2ImgStage(),
|
||||
stage,
|
||||
overlap=params.overlap,
|
||||
strength=highres.strength,
|
||||
)
|
||||
|
||||
return chain
|
|
@ -44,13 +44,14 @@ def stage_upscale_correction(
|
|||
chain: Optional[ChainPipeline] = None,
|
||||
pre_stages: List[PipelineStage] = None,
|
||||
post_stages: List[PipelineStage] = None,
|
||||
**kwargs,
|
||||
) -> ChainPipeline:
|
||||
"""
|
||||
This is a convenience method for a chain pipeline that will run upscaling and
|
||||
correction, based on the `upscale` params.
|
||||
"""
|
||||
logger.info(
|
||||
"running upscaling and correction pipeline at %s:%s",
|
||||
"staging upscaling and correction pipeline at %s:%s",
|
||||
upscale.scale,
|
||||
upscale.outscale,
|
||||
)
|
||||
|
@ -63,6 +64,7 @@ def stage_upscale_correction(
|
|||
chain.append((stage, pre_params, pre_opts))
|
||||
|
||||
upscale_opts = {
|
||||
**kwargs,
|
||||
"upscale": upscale,
|
||||
}
|
||||
upscale_stage = None
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
from logging import getLogger
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..chain import BlendImg2ImgStage, ChainPipeline
|
||||
from ..chain.highres import stage_highres
|
||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from ..worker.context import ProgressCallback
|
||||
from .upscale import stage_upscale_correction
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -18,14 +17,13 @@ class UpscaleHighresStage:
|
|||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
highres: HighresParams,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
pipeline: Optional[Any] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
|
@ -34,35 +32,7 @@ class UpscaleHighresStage:
|
|||
if highres.scale <= 1:
|
||||
return source
|
||||
|
||||
chain = ChainPipeline()
|
||||
scaled_size = (source.width * highres.scale, source.height * highres.scale)
|
||||
|
||||
# TODO: upscaling within the same stage prevents tiling from happening and causes OOM
|
||||
if highres.method == "bilinear":
|
||||
logger.debug("using bilinear interpolation for highres")
|
||||
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
|
||||
elif highres.method == "lanczos":
|
||||
logger.debug("using Lanczos interpolation for highres")
|
||||
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
||||
else:
|
||||
logger.debug("using upscaling pipeline for highres")
|
||||
stage_upscale_correction(
|
||||
StageParams(),
|
||||
params,
|
||||
upscale=upscale.with_args(
|
||||
faces=False,
|
||||
scale=highres.scale,
|
||||
outscale=highres.scale,
|
||||
),
|
||||
chain=chain,
|
||||
)
|
||||
|
||||
chain.stage(
|
||||
BlendImg2ImgStage(),
|
||||
StageParams(),
|
||||
overlap=params.overlap,
|
||||
strength=highres.strength,
|
||||
)
|
||||
chain = stage_highres(stage, params, highres, upscale)
|
||||
|
||||
return chain(
|
||||
job,
|
||||
|
|
|
@ -13,7 +13,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
|||
from ..server import ServerContext
|
||||
from ..utils import is_debug
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .utils import complete_tile, process_tile_grid, process_tile_order
|
||||
from .tile import complete_tile, process_tile_grid, process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class UpscaleSimpleStage:
|
||||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
method: str,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
||||
if upscale.scale <= 1:
|
||||
logger.debug(
|
||||
"simple upscale stage run with scale of %s, skipping", upscale.scale
|
||||
)
|
||||
return source
|
||||
|
||||
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
||||
|
||||
if method == "bilinear":
|
||||
logger.debug("using bilinear interpolation for highres")
|
||||
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
|
||||
elif method == "lanczos":
|
||||
logger.debug("using Lanczos interpolation for highres")
|
||||
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
||||
else:
|
||||
logger.warning("unknown upscaling method: %s", method)
|
||||
|
||||
return source
|
|
@ -3,12 +3,13 @@ from typing import Any, List, Optional
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.highres import stage_highres
|
||||
|
||||
from ..chain import (
|
||||
BlendImg2ImgStage,
|
||||
BlendMaskStage,
|
||||
ChainPipeline,
|
||||
SourceTxt2ImgStage,
|
||||
UpscaleHighresStage,
|
||||
UpscaleOutpaintStage,
|
||||
)
|
||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||
|
@ -60,14 +61,12 @@ def run_txt2img_pipeline(
|
|||
|
||||
# apply highres
|
||||
for _i in range(highres.iterations):
|
||||
chain.stage(
|
||||
UpscaleHighresStage(),
|
||||
StageParams(
|
||||
outscale=highres.scale,
|
||||
),
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
overlap=params.overlap,
|
||||
stage_highres(
|
||||
stage,
|
||||
params,
|
||||
highres,
|
||||
upscale,
|
||||
chain=chain,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, after highres
|
||||
|
@ -141,23 +140,22 @@ def run_img2img_pipeline(
|
|||
)
|
||||
|
||||
# loopback through multiple img2img iterations
|
||||
if params.loopback > 0:
|
||||
for _i in range(params.loopback):
|
||||
chain.stage(
|
||||
BlendImg2ImgStage(),
|
||||
stage,
|
||||
strength=strength,
|
||||
)
|
||||
for _i in range(params.loopback):
|
||||
chain.stage(
|
||||
BlendImg2ImgStage(),
|
||||
stage,
|
||||
strength=strength,
|
||||
)
|
||||
|
||||
# highres, if selected
|
||||
if highres.iterations > 0:
|
||||
for _i in range(highres.iterations):
|
||||
chain.stage(
|
||||
UpscaleHighresStage(),
|
||||
stage,
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
)
|
||||
for _i in range(highres.iterations):
|
||||
stage_highres(
|
||||
stage,
|
||||
params,
|
||||
highres,
|
||||
upscale,
|
||||
chain=chain,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, after highres
|
||||
stage_upscale_correction(
|
||||
|
@ -233,12 +231,14 @@ def run_inpaint_pipeline(
|
|||
)
|
||||
|
||||
# apply highres
|
||||
chain.stage(
|
||||
UpscaleHighresStage(),
|
||||
stage,
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
)
|
||||
for _i in range(highres.iterations):
|
||||
stage_highres(
|
||||
stage,
|
||||
params,
|
||||
highres,
|
||||
upscale,
|
||||
chain=chain,
|
||||
)
|
||||
|
||||
# apply upscaling and correction
|
||||
stage_upscale_correction(
|
||||
|
@ -299,12 +299,14 @@ def run_upscale_pipeline(
|
|||
)
|
||||
|
||||
# apply highres
|
||||
chain.stage(
|
||||
UpscaleHighresStage(),
|
||||
stage,
|
||||
highres=highres,
|
||||
upscale=upscale,
|
||||
)
|
||||
for _i in range(highres.iterations):
|
||||
stage_highres(
|
||||
stage,
|
||||
params,
|
||||
highres,
|
||||
upscale,
|
||||
chain=chain,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, after highres
|
||||
stage_upscale_correction(
|
||||
|
|
Loading…
Reference in New Issue