1
0
Fork 0

feat: add parameter to run correction before upscaling (#132)

This commit is contained in:
Sean Sube 2023-02-18 10:59:59 -06:00
parent 4d62404970
commit ecf3c03f0f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 68 additions and 11 deletions

View File

@ -72,6 +72,7 @@ class ChainPipeline:
"""
Append an additional stage to this pipeline.
"""
if stage is not None:
self.stages.append(stage)
def __call__(

View File

@ -22,6 +22,12 @@ class TileOrder:
spiral = "spiral"
class UpscaleOrder:
correct_first = "correct-first"
correct_last = "correct-last"
correct_both = "correct-both"
Param = Union[str, int, float]
Point = Tuple[int, int]
@ -168,6 +174,7 @@ class UpscaleParams:
scale: int = 4,
pre_pad: int = 0,
tile_pad: int = 10,
upscale_order: Literal["correction-first", "correction-last", "correction-both"] = "upscaling-first",
) -> None:
self.upscale_model = upscale_model
self.correction_model = correction_model
@ -181,6 +188,7 @@ class UpscaleParams:
self.pre_pad = pre_pad
self.scale = scale
self.tile_pad = tile_pad
self.upscale_order = upscale_order
def rescale(self, scale: int):
return UpscaleParams(
@ -196,6 +204,7 @@ class UpscaleParams:
scale=scale,
pre_pad=self.pre_pad,
tile_pad=self.tile_pad,
upscale_order=self.upscale_order,
)
def resize(self, size: Size) -> Size:
@ -218,4 +227,5 @@ class UpscaleParams:
"pre_pad": self.pre_pad,
"scale": self.scale,
"tile_pad": self.tile_pad,
"upscale_order": self.upscale_order,
}

View File

@ -255,6 +255,7 @@ def upscale_from_request() -> UpscaleParams:
faces = get_not_empty(request.args, "faces", "false") == "true"
face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
upscale_order = request.args.get("upscaleOrder", "correction-first")
return UpscaleParams(
upscaling,
@ -266,10 +267,11 @@ def upscale_from_request() -> UpscaleParams:
format="onnx",
outscale=outscale,
scale=scale,
upscale_order=upscale_order,
)
def check_paths(context: ServerContext):
def check_paths(context: ServerContext) -> None:
if not path.exists(context.model_path):
raise RuntimeError("model path must exist")
@ -283,7 +285,7 @@ def get_model_name(model: str) -> str:
return file
def load_models(context: ServerContext):
def load_models(context: ServerContext) -> None:
global diffusion_models
global correction_models
global upscaling_models
@ -313,7 +315,7 @@ def load_models(context: ServerContext):
upscaling_models.sort()
def load_params(context: ServerContext):
def load_params(context: ServerContext) -> None:
global config_params
params_file = path.join(context.params_path, "params.json")
with open(params_file, "r") as f:
@ -328,7 +330,7 @@ def load_params(context: ServerContext):
config_platform["default"] = context.default_platform
def load_platforms(context: ServerContext):
def load_platforms(context: ServerContext) -> None:
global available_platforms
providers = list(get_available_providers())

View File

@ -36,27 +36,42 @@ def run_upscale_correction(
if upscale.scale > 1:
if "esrgan" in upscale.upscale_model:
esrgan_stage = StageParams(
esrgan_params = StageParams(
tile_size=stage.tile_size, outscale=upscale.outscale
)
chain.append((upscale_resrgan, esrgan_stage, None))
upscale_stage = (upscale_resrgan, esrgan_params, None)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
chain.append((upscale_stable_diffusion, sd_stage, None))
upscale_stage = (upscale_stable_diffusion, sd_stage, None)
else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
upscale_stage = None
if upscale.faces:
face_stage = StageParams(
tile_size=stage.tile_size, outscale=upscale.face_outscale
)
if "codeformer" in upscale.correction_model:
chain.append((correct_codeformer, face_stage, None))
correct_stage = (correct_codeformer, face_stage, None)
elif "gfpgan" in upscale.correction_model:
chain.append((correct_gfpgan, face_stage, None))
correct_stage = (correct_gfpgan, face_stage, None)
else:
logger.warn("unknown correction model: %s", upscale.correction_model)
correct_stage = None
if upscale.upscale_order == "correction-both":
chain.append(correct_stage)
chain.append(upscale_stage)
chain.append(correct_stage)
elif upscale.upscale_order == "correction-first":
chain.append(correct_stage)
chain.append(upscale_stage)
elif upscale.upscale_order == "correction-last":
chain.append(upscale_stage)
chain.append(correct_stage)
else:
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
return chain(
job,

View File

@ -127,6 +127,14 @@
"max": 512,
"step": 8
},
"upscaleOrder": {
"default": "correction-last",
"keys": [
"correction-both",
"correction-first",
"correction-last"
]
},
"upscaling": {
"default": "",
"keys": []

View File

@ -114,6 +114,8 @@ export interface BrushParams {
*/
export interface UpscaleParams {
enabled: boolean;
upscaleOrder: string;
denoise: number;
scale: number;
outscale: number;

View File

@ -1,5 +1,6 @@
import { mustExist } from '@apextoaster/js-utils';
import { Checkbox, FormControlLabel, Stack } from '@mui/material';
import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
import { startCase } from 'lodash';
import * as React from 'react';
import { useContext } from 'react';
import { useStore } from 'zustand';
@ -106,5 +107,22 @@ export function UpscaleControl() {
});
}}
/>
<FormControl>
<InputLabel id={'upscale-order'}>Upscale Order</InputLabel>
<Select
labelId={'upscale-order'}
label={'Upscale Order'}
value={upscale.upscaleOrder}
onChange={(e) => {
setUpscale({
upscaleOrder: e.target.value,
});
}}
>
{params.upscaleOrder.keys.map((name) =>
<MenuItem key={name} value={name}>{startCase(name)}</MenuItem>)
}
</Select>
</FormControl>
</Stack>;
}

View File

@ -408,6 +408,7 @@ export function createStateSlices(server: ServerParams) {
faceStrength: server.faceStrength.default,
outscale: server.outscale.default,
scale: server.scale.default,
upscaleOrder: server.upscaleOrder.default,
},
upscaleTab: {
source: null,