1
0
Fork 0

fix(api): remove nested tiling in highres

This commit is contained in:
Sean Sube 2023-07-02 12:16:13 -05:00
parent eef055eddd
commit a7be651032
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 154 additions and 74 deletions

View File

@ -17,6 +17,7 @@ from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_simple import UpscaleSimpleStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
@ -39,6 +40,7 @@ CHAIN_STAGES = {
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-simple": UpscaleSimpleStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}

View File

@ -11,7 +11,7 @@ from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .utils import process_tile_order
from .tile import process_tile_order
logger = getLogger(__name__)

View File

@ -13,7 +13,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_order
from .tile import process_tile_order
logger = getLogger(__name__)

View File

@ -0,0 +1,58 @@
from logging import getLogger
from typing import Optional
from ..chain.base import ChainPipeline
from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
logger = getLogger(__name__)
def stage_highres(
stage: StageParams,
params: ImageParams,
highres: HighresParams,
upscale: UpscaleParams,
chain: Optional[ChainPipeline] = None,
) -> ChainPipeline:
logger.info("staging highres pipeline at %s", highres.scale)
if chain is None:
chain = ChainPipeline()
if highres.iterations < 1:
logger.debug("no highres iterations, skipping")
return chain
if highres.method == "upscale":
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
stage,
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
overlap=params.overlap,
)
else:
logger.debug("using simple upscaling for highres")
chain.stage(
UpscaleSimpleStage(),
stage,
overlap=params.overlap,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
)
chain.stage(
BlendImg2ImgStage(),
stage,
overlap=params.overlap,
strength=highres.strength,
)
return chain

View File

@ -44,13 +44,14 @@ def stage_upscale_correction(
chain: Optional[ChainPipeline] = None,
pre_stages: List[PipelineStage] = None,
post_stages: List[PipelineStage] = None,
**kwargs,
) -> ChainPipeline:
"""
This is a convenience method for a chain pipeline that will run upscaling and
correction, based on the `upscale` params.
"""
logger.info(
"running upscaling and correction pipeline at %s:%s",
"staging upscaling and correction pipeline at %s:%s",
upscale.scale,
upscale.outscale,
)
@ -63,6 +64,7 @@ def stage_upscale_correction(
chain.append((stage, pre_params, pre_opts))
upscale_opts = {
**kwargs,
"upscale": upscale,
}
upscale_stage = None

View File

@ -1,14 +1,13 @@
from logging import getLogger
from typing import Any, Optional
from typing import Optional
from PIL import Image
from ..chain import BlendImg2ImgStage, ChainPipeline
from ..chain.highres import stage_highres
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
from .upscale import stage_upscale_correction
logger = getLogger(__name__)
@ -18,14 +17,13 @@ class UpscaleHighresStage:
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
highres: HighresParams,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
pipeline: Optional[Any] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
@ -34,35 +32,7 @@ class UpscaleHighresStage:
if highres.scale <= 1:
return source
chain = ChainPipeline()
scaled_size = (source.width * highres.scale, source.height * highres.scale)
# TODO: upscaling within the same stage prevents tiling from happening and causes OOM
if highres.method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif highres.method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
StageParams(),
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
)
chain.stage(
BlendImg2ImgStage(),
StageParams(),
overlap=params.overlap,
strength=highres.strength,
)
chain = stage_highres(stage, params, highres, upscale)
return chain(
job,

View File

@ -13,7 +13,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import complete_tile, process_tile_grid, process_tile_order
from .tile import complete_tile, process_tile_grid, process_tile_order
logger = getLogger(__name__)

View File

@ -0,0 +1,46 @@
from logging import getLogger
from typing import Optional
from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
logger = getLogger(__name__)
class UpscaleSimpleStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
method: str,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
if upscale.scale <= 1:
logger.debug(
"simple upscale stage run with scale of %s, skipping", upscale.scale
)
return source
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
if method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.warning("unknown upscaling method: %s", method)
return source

View File

@ -3,12 +3,13 @@ from typing import Any, List, Optional
from PIL import Image
from onnx_web.chain.highres import stage_highres
from ..chain import (
BlendImg2ImgStage,
BlendMaskStage,
ChainPipeline,
SourceTxt2ImgStage,
UpscaleHighresStage,
UpscaleOutpaintStage,
)
from ..chain.upscale import split_upscale, stage_upscale_correction
@ -60,14 +61,12 @@ def run_txt2img_pipeline(
# apply highres
for _i in range(highres.iterations):
chain.stage(
UpscaleHighresStage(),
StageParams(
outscale=highres.scale,
),
highres=highres,
upscale=upscale,
overlap=params.overlap,
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)
# apply upscaling and correction, after highres
@ -141,23 +140,22 @@ def run_img2img_pipeline(
)
# loopback through multiple img2img iterations
if params.loopback > 0:
for _i in range(params.loopback):
chain.stage(
BlendImg2ImgStage(),
stage,
strength=strength,
)
for _i in range(params.loopback):
chain.stage(
BlendImg2ImgStage(),
stage,
strength=strength,
)
# highres, if selected
if highres.iterations > 0:
for _i in range(highres.iterations):
chain.stage(
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
)
for _i in range(highres.iterations):
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)
# apply upscaling and correction, after highres
stage_upscale_correction(
@ -233,12 +231,14 @@ def run_inpaint_pipeline(
)
# apply highres
chain.stage(
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
)
for _i in range(highres.iterations):
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)
# apply upscaling and correction
stage_upscale_correction(
@ -299,12 +299,14 @@ def run_upscale_pipeline(
)
# apply highres
chain.stage(
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
)
for _i in range(highres.iterations):
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)
# apply upscaling and correction, after highres
stage_upscale_correction(