1
0
Fork 0

fix(api): remove nested tiling in highres

This commit is contained in:
Sean Sube 2023-07-02 12:16:13 -05:00
parent eef055eddd
commit a7be651032
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 154 additions and 74 deletions

View File

@ -17,6 +17,7 @@ from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_simple import UpscaleSimpleStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage from .upscale_swinir import UpscaleSwinIRStage
@ -39,6 +40,7 @@ CHAIN_STAGES = {
"upscale-highres": UpscaleHighresStage, "upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage, "upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage, "upscale-resrgan": UpscaleRealESRGANStage,
"upscale-simple": UpscaleSimpleStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage, "upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage, "upscale-swinir": UpscaleSwinIRStage,
} }

View File

@ -11,7 +11,7 @@ from ..server import ServerContext
from ..utils import is_debug from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage from .stage import BaseStage
from .utils import process_tile_order from .tile import process_tile_order
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -13,7 +13,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..utils import is_debug from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_order from .tile import process_tile_order
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -0,0 +1,58 @@
from logging import getLogger
from typing import Optional
from ..chain.base import ChainPipeline
from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
logger = getLogger(__name__)
def stage_highres(
stage: StageParams,
params: ImageParams,
highres: HighresParams,
upscale: UpscaleParams,
chain: Optional[ChainPipeline] = None,
) -> ChainPipeline:
logger.info("staging highres pipeline at %s", highres.scale)
if chain is None:
chain = ChainPipeline()
if highres.iterations < 1:
logger.debug("no highres iterations, skipping")
return chain
if highres.method == "upscale":
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
stage,
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
overlap=params.overlap,
)
else:
logger.debug("using simple upscaling for highres")
chain.stage(
UpscaleSimpleStage(),
stage,
overlap=params.overlap,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
)
chain.stage(
BlendImg2ImgStage(),
stage,
overlap=params.overlap,
strength=highres.strength,
)
return chain

View File

@ -44,13 +44,14 @@ def stage_upscale_correction(
chain: Optional[ChainPipeline] = None, chain: Optional[ChainPipeline] = None,
pre_stages: List[PipelineStage] = None, pre_stages: List[PipelineStage] = None,
post_stages: List[PipelineStage] = None, post_stages: List[PipelineStage] = None,
**kwargs,
) -> ChainPipeline: ) -> ChainPipeline:
""" """
This is a convenience method for a chain pipeline that will run upscaling and This is a convenience method for a chain pipeline that will run upscaling and
correction, based on the `upscale` params. correction, based on the `upscale` params.
""" """
logger.info( logger.info(
"running upscaling and correction pipeline at %s:%s", "staging upscaling and correction pipeline at %s:%s",
upscale.scale, upscale.scale,
upscale.outscale, upscale.outscale,
) )
@ -63,6 +64,7 @@ def stage_upscale_correction(
chain.append((stage, pre_params, pre_opts)) chain.append((stage, pre_params, pre_opts))
upscale_opts = { upscale_opts = {
**kwargs,
"upscale": upscale, "upscale": upscale,
} }
upscale_stage = None upscale_stage = None

View File

@ -1,14 +1,13 @@
from logging import getLogger from logging import getLogger
from typing import Any, Optional from typing import Optional
from PIL import Image from PIL import Image
from ..chain import BlendImg2ImgStage, ChainPipeline from ..chain.highres import stage_highres
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from ..worker.context import ProgressCallback from ..worker.context import ProgressCallback
from .upscale import stage_upscale_correction
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,14 +17,13 @@ class UpscaleHighresStage:
self, self,
job: WorkerContext, job: WorkerContext,
server: ServerContext, server: ServerContext,
_stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, source: Image.Image,
*, *,
highres: HighresParams, highres: HighresParams,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
pipeline: Optional[Any] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
@ -34,35 +32,7 @@ class UpscaleHighresStage:
if highres.scale <= 1: if highres.scale <= 1:
return source return source
chain = ChainPipeline() chain = stage_highres(stage, params, highres, upscale)
scaled_size = (source.width * highres.scale, source.height * highres.scale)
# TODO: upscaling within the same stage prevents tiling from happening and causes OOM
if highres.method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif highres.method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
StageParams(),
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
)
chain.stage(
BlendImg2ImgStage(),
StageParams(),
overlap=params.overlap,
strength=highres.strength,
)
return chain( return chain(
job, job,

View File

@ -13,7 +13,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..utils import is_debug from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .utils import complete_tile, process_tile_grid, process_tile_order from .tile import complete_tile, process_tile_grid, process_tile_order
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -0,0 +1,46 @@
from logging import getLogger
from typing import Optional
from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
logger = getLogger(__name__)
class UpscaleSimpleStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
method: str,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
if upscale.scale <= 1:
logger.debug(
"simple upscale stage run with scale of %s, skipping", upscale.scale
)
return source
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
if method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.warning("unknown upscaling method: %s", method)
return source

View File

@ -3,12 +3,13 @@ from typing import Any, List, Optional
from PIL import Image from PIL import Image
from onnx_web.chain.highres import stage_highres
from ..chain import ( from ..chain import (
BlendImg2ImgStage, BlendImg2ImgStage,
BlendMaskStage, BlendMaskStage,
ChainPipeline, ChainPipeline,
SourceTxt2ImgStage, SourceTxt2ImgStage,
UpscaleHighresStage,
UpscaleOutpaintStage, UpscaleOutpaintStage,
) )
from ..chain.upscale import split_upscale, stage_upscale_correction from ..chain.upscale import split_upscale, stage_upscale_correction
@ -60,14 +61,12 @@ def run_txt2img_pipeline(
# apply highres # apply highres
for _i in range(highres.iterations): for _i in range(highres.iterations):
chain.stage( stage_highres(
UpscaleHighresStage(), stage,
StageParams( params,
outscale=highres.scale, highres,
), upscale,
highres=highres, chain=chain,
upscale=upscale,
overlap=params.overlap,
) )
# apply upscaling and correction, after highres # apply upscaling and correction, after highres
@ -141,7 +140,6 @@ def run_img2img_pipeline(
) )
# loopback through multiple img2img iterations # loopback through multiple img2img iterations
if params.loopback > 0:
for _i in range(params.loopback): for _i in range(params.loopback):
chain.stage( chain.stage(
BlendImg2ImgStage(), BlendImg2ImgStage(),
@ -150,13 +148,13 @@ def run_img2img_pipeline(
) )
# highres, if selected # highres, if selected
if highres.iterations > 0:
for _i in range(highres.iterations): for _i in range(highres.iterations):
chain.stage( stage_highres(
UpscaleHighresStage(),
stage, stage,
highres=highres, params,
upscale=upscale, highres,
upscale,
chain=chain,
) )
# apply upscaling and correction, after highres # apply upscaling and correction, after highres
@ -233,11 +231,13 @@ def run_inpaint_pipeline(
) )
# apply highres # apply highres
chain.stage( for _i in range(highres.iterations):
UpscaleHighresStage(), stage_highres(
stage, stage,
highres=highres, params,
upscale=upscale, highres,
upscale,
chain=chain,
) )
# apply upscaling and correction # apply upscaling and correction
@ -299,11 +299,13 @@ def run_upscale_pipeline(
) )
# apply highres # apply highres
chain.stage( for _i in range(highres.iterations):
UpscaleHighresStage(), stage_highres(
stage, stage,
highres=highres, params,
upscale=upscale, highres,
upscale,
chain=chain,
) )
# apply upscaling and correction, after highres # apply upscaling and correction, after highres