2023-01-28 23:09:19 +00:00
|
|
|
from logging import getLogger
|
2023-06-30 04:06:36 +00:00
|
|
|
from typing import List, Optional, Tuple
|
2023-01-16 00:13:28 +00:00
|
|
|
|
2023-04-01 16:26:10 +00:00
|
|
|
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
2023-07-01 12:10:53 +00:00
|
|
|
from . import ChainPipeline, PipelineStage
|
|
|
|
from .correct_codeformer import CorrectCodeformerStage
|
|
|
|
from .correct_gfpgan import CorrectGFPGANStage
|
|
|
|
from .upscale_bsrgan import UpscaleBSRGANStage
|
|
|
|
from .upscale_resrgan import UpscaleRealESRGANStage
|
|
|
|
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
|
|
|
from .upscale_swinir import UpscaleSwinIRStage
|
2023-01-16 21:11:40 +00:00
|
|
|
|
2023-01-28 23:09:19 +00:00
|
|
|
# Module-level logger, named after this module via `__name__`.
logger = getLogger(__name__)
|
|
|
|
|
2023-01-16 19:02:15 +00:00
|
|
|
|
2023-06-30 04:06:36 +00:00
|
|
|
def split_upscale(
    upscale: UpscaleParams,
) -> Tuple[Optional[UpscaleParams], UpscaleParams]:
    """
    Split one set of upscale params into an optional first pass and a final pass.

    When `upscale.faces` is truthy and the order is `correction-both` or
    `correction-first`, the first element is `upscale.with_args(scale=1, outscale=1)`
    and the second is `upscale.with_args(upscale_order="correction-last")`.

    In every other case there is no first pass: the first element is None and the
    second is the `upscale` object that was passed in, untouched.
    """
    early_correction_orders = ("correction-both", "correction-first")

    # guard: nothing to split unless faces are corrected before upscaling
    if not upscale.faces or upscale.upscale_order not in early_correction_orders:
        return None, upscale

    first_pass = upscale.with_args(scale=1, outscale=1)
    final_pass = upscale.with_args(upscale_order="correction-last")
    return first_pass, final_pass
|
|
|
|
|
|
|
|
|
2023-07-01 02:42:24 +00:00
|
|
|
def _select_upscale_stage(
    stage: StageParams,
    upscale: UpscaleParams,
    upscale_opts: dict,
) -> Optional[PipelineStage]:
    """
    Build the upscaling stage whose keyword appears in `upscale.upscale_model`.

    Returns None, after logging a warning, when no known keyword matches.
    """
    model = upscale.upscale_model
    tile_size = stage.tile_size

    # keywords are checked in a fixed order; the first substring match wins
    if "bsrgan" in model:
        stage_class = UpscaleBSRGANStage
    elif "esrgan" in model:
        stage_class = UpscaleRealESRGANStage
    elif "stable-diffusion" in model:
        stage_class = UpscaleStableDiffusionStage
        # the stable diffusion upscaler never gets tiles larger than SizeChart.mini
        tile_size = min(SizeChart.mini, stage.tile_size)
    elif "swinir" in model:
        stage_class = UpscaleSwinIRStage
    else:
        logger.warning("unknown upscaling model: %s", model)
        return None

    stage_params = StageParams(tile_size=tile_size, outscale=upscale.outscale)
    return (stage_class(), stage_params, upscale_opts)


def _select_correct_stage(
    stage: StageParams,
    upscale: UpscaleParams,
    upscale_opts: dict,
) -> Optional[PipelineStage]:
    """
    Build the face correction stage whose keyword appears in `upscale.correction_model`.

    Returns None, after logging a warning, when the model is unset or unknown.
    """
    model = upscale.correction_model
    if model is None:
        logger.warning("no correction model set, skipping")
        return None

    if "codeformer" in model:
        stage_class = CorrectCodeformerStage
    elif "gfpgan" in model:
        stage_class = CorrectGFPGANStage
    else:
        logger.warning("unknown correction model: %s", model)
        return None

    face_params = StageParams(
        tile_size=stage.tile_size, outscale=upscale.face_outscale
    )
    return (stage_class(), face_params, upscale_opts)


def stage_upscale_correction(
    stage: StageParams,
    params: ImageParams,
    *,
    upscale: UpscaleParams,
    chain: Optional[ChainPipeline] = None,
    pre_stages: Optional[List[PipelineStage]] = None,
    post_stages: Optional[List[PipelineStage]] = None,
    **kwargs,
) -> ChainPipeline:
    """
    This is a convenience method for a chain pipeline that will run upscaling and
    correction, based on the `upscale` params.

    Stages are appended in this order: every entry of `pre_stages`, then the
    upscaling and correction stages arranged according to `upscale.upscale_order`
    (`correction-both`, `correction-first`, or `correction-last`), then every
    entry of `post_stages`.

    - `stage` supplies the tile size for the stages built here.
    - `params` is accepted as part of the signature but is not read by this function.
    - `upscale` selects the models, scales, and ordering. An upscaling stage is
      only built when `upscale.scale > 1`, and a correction stage only when
      `upscale.faces` is truthy.
    - `chain` is extended in place when given; otherwise a new `ChainPipeline`
      is created.
    - `kwargs` are forwarded to each built stage, together with `upscale`.

    A stage that could not be built (disabled, or missing/unknown model) is left
    out of the chain instead of being appended as None. An unknown
    `upscale_order` logs a warning and adds neither stage.

    Returns the chain that was extended.
    """
    logger.info(
        "staging upscaling and correction pipeline at %s:%s",
        upscale.scale,
        upscale.outscale,
    )

    if chain is None:
        chain = ChainPipeline()

    if pre_stages is not None:
        for pre_stage in pre_stages:
            chain.append(pre_stage)

    # `upscale` is listed last so it always wins over a same-named key in kwargs
    upscale_opts = {
        **kwargs,
        "upscale": upscale,
    }

    upscale_stage: Optional[PipelineStage] = None
    if upscale.scale > 1:
        upscale_stage = _select_upscale_stage(stage, upscale, upscale_opts)

    correct_stage: Optional[PipelineStage] = None
    if upscale.faces:
        correct_stage = _select_correct_stage(stage, upscale, upscale_opts)

    stage_orders = {
        "correction-both": (correct_stage, upscale_stage, correct_stage),
        "correction-first": (correct_stage, upscale_stage),
        "correction-last": (upscale_stage, correct_stage),
    }
    ordered_stages = stage_orders.get(upscale.upscale_order)
    if ordered_stages is None:
        logger.warning("unknown upscaling order: %s", upscale.upscale_order)
    else:
        for selected_stage in ordered_stages:
            # either stage may be None here; never hand None to the chain
            if selected_stage is not None:
                chain.append(selected_stage)

    if post_stages is not None:
        for post_stage in post_stages:
            chain.append(post_stage)

    return chain
|