feat: add parameter to run correction before upscaling (#132)
This commit is contained in:
parent
4d62404970
commit
ecf3c03f0f
|
@ -72,7 +72,8 @@ class ChainPipeline:
|
||||||
"""
|
"""
|
||||||
Append an additional stage to this pipeline.
|
Append an additional stage to this pipeline.
|
||||||
"""
|
"""
|
||||||
self.stages.append(stage)
|
if stage is not None:
|
||||||
|
self.stages.append(stage)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -22,6 +22,12 @@ class TileOrder:
|
||||||
spiral = "spiral"
|
spiral = "spiral"
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleOrder:
|
||||||
|
correct_first = "correct-first"
|
||||||
|
correct_last = "correct-last"
|
||||||
|
correct_both = "correct-both"
|
||||||
|
|
||||||
|
|
||||||
Param = Union[str, int, float]
|
Param = Union[str, int, float]
|
||||||
Point = Tuple[int, int]
|
Point = Tuple[int, int]
|
||||||
|
|
||||||
|
@ -168,6 +174,7 @@ class UpscaleParams:
|
||||||
scale: int = 4,
|
scale: int = 4,
|
||||||
pre_pad: int = 0,
|
pre_pad: int = 0,
|
||||||
tile_pad: int = 10,
|
tile_pad: int = 10,
|
||||||
|
upscale_order: Literal["correction-first", "correction-last", "correction-both"] = "upscaling-first",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.upscale_model = upscale_model
|
self.upscale_model = upscale_model
|
||||||
self.correction_model = correction_model
|
self.correction_model = correction_model
|
||||||
|
@ -181,6 +188,7 @@ class UpscaleParams:
|
||||||
self.pre_pad = pre_pad
|
self.pre_pad = pre_pad
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.tile_pad = tile_pad
|
self.tile_pad = tile_pad
|
||||||
|
self.upscale_order = upscale_order
|
||||||
|
|
||||||
def rescale(self, scale: int):
|
def rescale(self, scale: int):
|
||||||
return UpscaleParams(
|
return UpscaleParams(
|
||||||
|
@ -196,6 +204,7 @@ class UpscaleParams:
|
||||||
scale=scale,
|
scale=scale,
|
||||||
pre_pad=self.pre_pad,
|
pre_pad=self.pre_pad,
|
||||||
tile_pad=self.tile_pad,
|
tile_pad=self.tile_pad,
|
||||||
|
upscale_order=self.upscale_order,
|
||||||
)
|
)
|
||||||
|
|
||||||
def resize(self, size: Size) -> Size:
|
def resize(self, size: Size) -> Size:
|
||||||
|
@ -218,4 +227,5 @@ class UpscaleParams:
|
||||||
"pre_pad": self.pre_pad,
|
"pre_pad": self.pre_pad,
|
||||||
"scale": self.scale,
|
"scale": self.scale,
|
||||||
"tile_pad": self.tile_pad,
|
"tile_pad": self.tile_pad,
|
||||||
|
"upscale_order": self.upscale_order,
|
||||||
}
|
}
|
||||||
|
|
|
@ -255,6 +255,7 @@ def upscale_from_request() -> UpscaleParams:
|
||||||
faces = get_not_empty(request.args, "faces", "false") == "true"
|
faces = get_not_empty(request.args, "faces", "false") == "true"
|
||||||
face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
|
face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
|
||||||
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
|
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
|
||||||
|
upscale_order = request.args.get("upscaleOrder", "correction-first")
|
||||||
|
|
||||||
return UpscaleParams(
|
return UpscaleParams(
|
||||||
upscaling,
|
upscaling,
|
||||||
|
@ -266,10 +267,11 @@ def upscale_from_request() -> UpscaleParams:
|
||||||
format="onnx",
|
format="onnx",
|
||||||
outscale=outscale,
|
outscale=outscale,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
|
upscale_order=upscale_order,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_paths(context: ServerContext):
|
def check_paths(context: ServerContext) -> None:
|
||||||
if not path.exists(context.model_path):
|
if not path.exists(context.model_path):
|
||||||
raise RuntimeError("model path must exist")
|
raise RuntimeError("model path must exist")
|
||||||
|
|
||||||
|
@ -283,7 +285,7 @@ def get_model_name(model: str) -> str:
|
||||||
return file
|
return file
|
||||||
|
|
||||||
|
|
||||||
def load_models(context: ServerContext):
|
def load_models(context: ServerContext) -> None:
|
||||||
global diffusion_models
|
global diffusion_models
|
||||||
global correction_models
|
global correction_models
|
||||||
global upscaling_models
|
global upscaling_models
|
||||||
|
@ -313,7 +315,7 @@ def load_models(context: ServerContext):
|
||||||
upscaling_models.sort()
|
upscaling_models.sort()
|
||||||
|
|
||||||
|
|
||||||
def load_params(context: ServerContext):
|
def load_params(context: ServerContext) -> None:
|
||||||
global config_params
|
global config_params
|
||||||
params_file = path.join(context.params_path, "params.json")
|
params_file = path.join(context.params_path, "params.json")
|
||||||
with open(params_file, "r") as f:
|
with open(params_file, "r") as f:
|
||||||
|
@ -328,7 +330,7 @@ def load_params(context: ServerContext):
|
||||||
config_platform["default"] = context.default_platform
|
config_platform["default"] = context.default_platform
|
||||||
|
|
||||||
|
|
||||||
def load_platforms(context: ServerContext):
|
def load_platforms(context: ServerContext) -> None:
|
||||||
global available_platforms
|
global available_platforms
|
||||||
|
|
||||||
providers = list(get_available_providers())
|
providers = list(get_available_providers())
|
||||||
|
|
|
@ -36,27 +36,42 @@ def run_upscale_correction(
|
||||||
|
|
||||||
if upscale.scale > 1:
|
if upscale.scale > 1:
|
||||||
if "esrgan" in upscale.upscale_model:
|
if "esrgan" in upscale.upscale_model:
|
||||||
esrgan_stage = StageParams(
|
esrgan_params = StageParams(
|
||||||
tile_size=stage.tile_size, outscale=upscale.outscale
|
tile_size=stage.tile_size, outscale=upscale.outscale
|
||||||
)
|
)
|
||||||
chain.append((upscale_resrgan, esrgan_stage, None))
|
upscale_stage = (upscale_resrgan, esrgan_params, None)
|
||||||
elif "stable-diffusion" in upscale.upscale_model:
|
elif "stable-diffusion" in upscale.upscale_model:
|
||||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||||
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||||
chain.append((upscale_stable_diffusion, sd_stage, None))
|
upscale_stage = (upscale_stable_diffusion, sd_stage, None)
|
||||||
else:
|
else:
|
||||||
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
||||||
|
upscale_stage = None
|
||||||
|
|
||||||
if upscale.faces:
|
if upscale.faces:
|
||||||
face_stage = StageParams(
|
face_stage = StageParams(
|
||||||
tile_size=stage.tile_size, outscale=upscale.face_outscale
|
tile_size=stage.tile_size, outscale=upscale.face_outscale
|
||||||
)
|
)
|
||||||
if "codeformer" in upscale.correction_model:
|
if "codeformer" in upscale.correction_model:
|
||||||
chain.append((correct_codeformer, face_stage, None))
|
correct_stage = (correct_codeformer, face_stage, None)
|
||||||
elif "gfpgan" in upscale.correction_model:
|
elif "gfpgan" in upscale.correction_model:
|
||||||
chain.append((correct_gfpgan, face_stage, None))
|
correct_stage = (correct_gfpgan, face_stage, None)
|
||||||
else:
|
else:
|
||||||
logger.warn("unknown correction model: %s", upscale.correction_model)
|
logger.warn("unknown correction model: %s", upscale.correction_model)
|
||||||
|
correct_stage = None
|
||||||
|
|
||||||
|
if upscale.upscale_order == "correction-both":
|
||||||
|
chain.append(correct_stage)
|
||||||
|
chain.append(upscale_stage)
|
||||||
|
chain.append(correct_stage)
|
||||||
|
elif upscale.upscale_order == "correction-first":
|
||||||
|
chain.append(correct_stage)
|
||||||
|
chain.append(upscale_stage)
|
||||||
|
elif upscale.upscale_order == "correction-last":
|
||||||
|
chain.append(upscale_stage)
|
||||||
|
chain.append(correct_stage)
|
||||||
|
else:
|
||||||
|
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
|
||||||
|
|
||||||
return chain(
|
return chain(
|
||||||
job,
|
job,
|
||||||
|
|
|
@ -127,6 +127,14 @@
|
||||||
"max": 512,
|
"max": 512,
|
||||||
"step": 8
|
"step": 8
|
||||||
},
|
},
|
||||||
|
"upscaleOrder": {
|
||||||
|
"default": "correction-last",
|
||||||
|
"keys": [
|
||||||
|
"correction-both",
|
||||||
|
"correction-first",
|
||||||
|
"correction-last"
|
||||||
|
]
|
||||||
|
},
|
||||||
"upscaling": {
|
"upscaling": {
|
||||||
"default": "",
|
"default": "",
|
||||||
"keys": []
|
"keys": []
|
||||||
|
|
|
@ -114,6 +114,8 @@ export interface BrushParams {
|
||||||
*/
|
*/
|
||||||
export interface UpscaleParams {
|
export interface UpscaleParams {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
|
upscaleOrder: string;
|
||||||
|
|
||||||
denoise: number;
|
denoise: number;
|
||||||
scale: number;
|
scale: number;
|
||||||
outscale: number;
|
outscale: number;
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import { mustExist } from '@apextoaster/js-utils';
|
import { mustExist } from '@apextoaster/js-utils';
|
||||||
import { Checkbox, FormControlLabel, Stack } from '@mui/material';
|
import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
|
||||||
|
import { startCase } from 'lodash';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { useContext } from 'react';
|
import { useContext } from 'react';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
@ -106,5 +107,22 @@ export function UpscaleControl() {
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
<FormControl>
|
||||||
|
<InputLabel id={'upscale-order'}>Upscale Order</InputLabel>
|
||||||
|
<Select
|
||||||
|
labelId={'upscale-order'}
|
||||||
|
label={'Upscale Order'}
|
||||||
|
value={upscale.upscaleOrder}
|
||||||
|
onChange={(e) => {
|
||||||
|
setUpscale({
|
||||||
|
upscaleOrder: e.target.value,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{params.upscaleOrder.keys.map((name) =>
|
||||||
|
<MenuItem key={name} value={name}>{startCase(name)}</MenuItem>)
|
||||||
|
}
|
||||||
|
</Select>
|
||||||
|
</FormControl>
|
||||||
</Stack>;
|
</Stack>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -408,6 +408,7 @@ export function createStateSlices(server: ServerParams) {
|
||||||
faceStrength: server.faceStrength.default,
|
faceStrength: server.faceStrength.default,
|
||||||
outscale: server.outscale.default,
|
outscale: server.outscale.default,
|
||||||
scale: server.scale.default,
|
scale: server.scale.default,
|
||||||
|
upscaleOrder: server.upscaleOrder.default,
|
||||||
},
|
},
|
||||||
upscaleTab: {
|
upscaleTab: {
|
||||||
source: null,
|
source: null,
|
||||||
|
|
Loading…
Reference in New Issue