1
0
Fork 0

feat: add parameter to run correction before upscaling (#132)

This commit is contained in:
Sean Sube 2023-02-18 10:59:59 -06:00
parent 4d62404970
commit ecf3c03f0f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 68 additions and 11 deletions

View File

@ -72,7 +72,8 @@ class ChainPipeline:
""" """
Append an additional stage to this pipeline. Append an additional stage to this pipeline.
""" """
self.stages.append(stage) if stage is not None:
self.stages.append(stage)
def __call__( def __call__(
self, self,

View File

@ -22,6 +22,12 @@ class TileOrder:
spiral = "spiral" spiral = "spiral"
class UpscaleOrder:
correct_first = "correct-first"
correct_last = "correct-last"
correct_both = "correct-both"
Param = Union[str, int, float] Param = Union[str, int, float]
Point = Tuple[int, int] Point = Tuple[int, int]
@ -168,6 +174,7 @@ class UpscaleParams:
scale: int = 4, scale: int = 4,
pre_pad: int = 0, pre_pad: int = 0,
tile_pad: int = 10, tile_pad: int = 10,
upscale_order: Literal["correction-first", "correction-last", "correction-both"] = "upscaling-first",
) -> None: ) -> None:
self.upscale_model = upscale_model self.upscale_model = upscale_model
self.correction_model = correction_model self.correction_model = correction_model
@ -181,6 +188,7 @@ class UpscaleParams:
self.pre_pad = pre_pad self.pre_pad = pre_pad
self.scale = scale self.scale = scale
self.tile_pad = tile_pad self.tile_pad = tile_pad
self.upscale_order = upscale_order
def rescale(self, scale: int): def rescale(self, scale: int):
return UpscaleParams( return UpscaleParams(
@ -196,6 +204,7 @@ class UpscaleParams:
scale=scale, scale=scale,
pre_pad=self.pre_pad, pre_pad=self.pre_pad,
tile_pad=self.tile_pad, tile_pad=self.tile_pad,
upscale_order=self.upscale_order,
) )
def resize(self, size: Size) -> Size: def resize(self, size: Size) -> Size:
@ -218,4 +227,5 @@ class UpscaleParams:
"pre_pad": self.pre_pad, "pre_pad": self.pre_pad,
"scale": self.scale, "scale": self.scale,
"tile_pad": self.tile_pad, "tile_pad": self.tile_pad,
"upscale_order": self.upscale_order,
} }

View File

@ -255,6 +255,7 @@ def upscale_from_request() -> UpscaleParams:
faces = get_not_empty(request.args, "faces", "false") == "true" faces = get_not_empty(request.args, "faces", "false") == "true"
face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1) face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
upscale_order = request.args.get("upscaleOrder", "correction-first")
return UpscaleParams( return UpscaleParams(
upscaling, upscaling,
@ -266,10 +267,11 @@ def upscale_from_request() -> UpscaleParams:
format="onnx", format="onnx",
outscale=outscale, outscale=outscale,
scale=scale, scale=scale,
upscale_order=upscale_order,
) )
def check_paths(context: ServerContext): def check_paths(context: ServerContext) -> None:
if not path.exists(context.model_path): if not path.exists(context.model_path):
raise RuntimeError("model path must exist") raise RuntimeError("model path must exist")
@ -283,7 +285,7 @@ def get_model_name(model: str) -> str:
return file return file
def load_models(context: ServerContext): def load_models(context: ServerContext) -> None:
global diffusion_models global diffusion_models
global correction_models global correction_models
global upscaling_models global upscaling_models
@ -313,7 +315,7 @@ def load_models(context: ServerContext):
upscaling_models.sort() upscaling_models.sort()
def load_params(context: ServerContext): def load_params(context: ServerContext) -> None:
global config_params global config_params
params_file = path.join(context.params_path, "params.json") params_file = path.join(context.params_path, "params.json")
with open(params_file, "r") as f: with open(params_file, "r") as f:
@ -328,7 +330,7 @@ def load_params(context: ServerContext):
config_platform["default"] = context.default_platform config_platform["default"] = context.default_platform
def load_platforms(context: ServerContext): def load_platforms(context: ServerContext) -> None:
global available_platforms global available_platforms
providers = list(get_available_providers()) providers = list(get_available_providers())

View File

@ -36,27 +36,42 @@ def run_upscale_correction(
if upscale.scale > 1: if upscale.scale > 1:
if "esrgan" in upscale.upscale_model: if "esrgan" in upscale.upscale_model:
esrgan_stage = StageParams( esrgan_params = StageParams(
tile_size=stage.tile_size, outscale=upscale.outscale tile_size=stage.tile_size, outscale=upscale.outscale
) )
chain.append((upscale_resrgan, esrgan_stage, None)) upscale_stage = (upscale_resrgan, esrgan_params, None)
elif "stable-diffusion" in upscale.upscale_model: elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size) mini_tile = min(SizeChart.mini, stage.tile_size)
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale) sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
chain.append((upscale_stable_diffusion, sd_stage, None)) upscale_stage = (upscale_stable_diffusion, sd_stage, None)
else: else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model) logger.warn("unknown upscaling model: %s", upscale.upscale_model)
upscale_stage = None
if upscale.faces: if upscale.faces:
face_stage = StageParams( face_stage = StageParams(
tile_size=stage.tile_size, outscale=upscale.face_outscale tile_size=stage.tile_size, outscale=upscale.face_outscale
) )
if "codeformer" in upscale.correction_model: if "codeformer" in upscale.correction_model:
chain.append((correct_codeformer, face_stage, None)) correct_stage = (correct_codeformer, face_stage, None)
elif "gfpgan" in upscale.correction_model: elif "gfpgan" in upscale.correction_model:
chain.append((correct_gfpgan, face_stage, None)) correct_stage = (correct_gfpgan, face_stage, None)
else: else:
logger.warn("unknown correction model: %s", upscale.correction_model) logger.warn("unknown correction model: %s", upscale.correction_model)
correct_stage = None
if upscale.upscale_order == "correction-both":
chain.append(correct_stage)
chain.append(upscale_stage)
chain.append(correct_stage)
elif upscale.upscale_order == "correction-first":
chain.append(correct_stage)
chain.append(upscale_stage)
elif upscale.upscale_order == "correction-last":
chain.append(upscale_stage)
chain.append(correct_stage)
else:
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
return chain( return chain(
job, job,

View File

@ -127,6 +127,14 @@
"max": 512, "max": 512,
"step": 8 "step": 8
}, },
"upscaleOrder": {
"default": "correction-last",
"keys": [
"correction-both",
"correction-first",
"correction-last"
]
},
"upscaling": { "upscaling": {
"default": "", "default": "",
"keys": [] "keys": []

View File

@ -114,6 +114,8 @@ export interface BrushParams {
*/ */
export interface UpscaleParams { export interface UpscaleParams {
enabled: boolean; enabled: boolean;
upscaleOrder: string;
denoise: number; denoise: number;
scale: number; scale: number;
outscale: number; outscale: number;

View File

@ -1,5 +1,6 @@
import { mustExist } from '@apextoaster/js-utils'; import { mustExist } from '@apextoaster/js-utils';
import { Checkbox, FormControlLabel, Stack } from '@mui/material'; import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
import { startCase } from 'lodash';
import * as React from 'react'; import * as React from 'react';
import { useContext } from 'react'; import { useContext } from 'react';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
@ -106,5 +107,22 @@ export function UpscaleControl() {
}); });
}} }}
/> />
<FormControl>
<InputLabel id={'upscale-order'}>Upscale Order</InputLabel>
<Select
labelId={'upscale-order'}
label={'Upscale Order'}
value={upscale.upscaleOrder}
onChange={(e) => {
setUpscale({
upscaleOrder: e.target.value,
});
}}
>
{params.upscaleOrder.keys.map((name) =>
<MenuItem key={name} value={name}>{startCase(name)}</MenuItem>)
}
</Select>
</FormControl>
</Stack>; </Stack>;
} }

View File

@ -408,6 +408,7 @@ export function createStateSlices(server: ServerParams) {
faceStrength: server.faceStrength.default, faceStrength: server.faceStrength.default,
outscale: server.outscale.default, outscale: server.outscale.default,
scale: server.scale.default, scale: server.scale.default,
upscaleOrder: server.upscaleOrder.default,
}, },
upscaleTab: { upscaleTab: {
source: null, source: null,