feat: add parameter to run correction before upscaling (#132)
This commit is contained in:
parent
4d62404970
commit
ecf3c03f0f
|
@ -72,7 +72,8 @@ class ChainPipeline:
|
|||
"""
|
||||
Append an additional stage to this pipeline.
|
||||
"""
|
||||
self.stages.append(stage)
|
||||
if stage is not None:
|
||||
self.stages.append(stage)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
|
|
@ -22,6 +22,12 @@ class TileOrder:
|
|||
spiral = "spiral"
|
||||
|
||||
|
||||
class UpscaleOrder:
|
||||
correct_first = "correct-first"
|
||||
correct_last = "correct-last"
|
||||
correct_both = "correct-both"
|
||||
|
||||
|
||||
Param = Union[str, int, float]
|
||||
Point = Tuple[int, int]
|
||||
|
||||
|
@ -168,6 +174,7 @@ class UpscaleParams:
|
|||
scale: int = 4,
|
||||
pre_pad: int = 0,
|
||||
tile_pad: int = 10,
|
||||
upscale_order: Literal["correction-first", "correction-last", "correction-both"] = "upscaling-first",
|
||||
) -> None:
|
||||
self.upscale_model = upscale_model
|
||||
self.correction_model = correction_model
|
||||
|
@ -181,6 +188,7 @@ class UpscaleParams:
|
|||
self.pre_pad = pre_pad
|
||||
self.scale = scale
|
||||
self.tile_pad = tile_pad
|
||||
self.upscale_order = upscale_order
|
||||
|
||||
def rescale(self, scale: int):
|
||||
return UpscaleParams(
|
||||
|
@ -196,6 +204,7 @@ class UpscaleParams:
|
|||
scale=scale,
|
||||
pre_pad=self.pre_pad,
|
||||
tile_pad=self.tile_pad,
|
||||
upscale_order=self.upscale_order,
|
||||
)
|
||||
|
||||
def resize(self, size: Size) -> Size:
|
||||
|
@ -218,4 +227,5 @@ class UpscaleParams:
|
|||
"pre_pad": self.pre_pad,
|
||||
"scale": self.scale,
|
||||
"tile_pad": self.tile_pad,
|
||||
"upscale_order": self.upscale_order,
|
||||
}
|
||||
|
|
|
@ -255,6 +255,7 @@ def upscale_from_request() -> UpscaleParams:
|
|||
faces = get_not_empty(request.args, "faces", "false") == "true"
|
||||
face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
|
||||
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
|
||||
upscale_order = request.args.get("upscaleOrder", "correction-first")
|
||||
|
||||
return UpscaleParams(
|
||||
upscaling,
|
||||
|
@ -266,10 +267,11 @@ def upscale_from_request() -> UpscaleParams:
|
|||
format="onnx",
|
||||
outscale=outscale,
|
||||
scale=scale,
|
||||
upscale_order=upscale_order,
|
||||
)
|
||||
|
||||
|
||||
def check_paths(context: ServerContext):
|
||||
def check_paths(context: ServerContext) -> None:
|
||||
if not path.exists(context.model_path):
|
||||
raise RuntimeError("model path must exist")
|
||||
|
||||
|
@ -283,7 +285,7 @@ def get_model_name(model: str) -> str:
|
|||
return file
|
||||
|
||||
|
||||
def load_models(context: ServerContext):
|
||||
def load_models(context: ServerContext) -> None:
|
||||
global diffusion_models
|
||||
global correction_models
|
||||
global upscaling_models
|
||||
|
@ -313,7 +315,7 @@ def load_models(context: ServerContext):
|
|||
upscaling_models.sort()
|
||||
|
||||
|
||||
def load_params(context: ServerContext):
|
||||
def load_params(context: ServerContext) -> None:
|
||||
global config_params
|
||||
params_file = path.join(context.params_path, "params.json")
|
||||
with open(params_file, "r") as f:
|
||||
|
@ -328,7 +330,7 @@ def load_params(context: ServerContext):
|
|||
config_platform["default"] = context.default_platform
|
||||
|
||||
|
||||
def load_platforms(context: ServerContext):
|
||||
def load_platforms(context: ServerContext) -> None:
|
||||
global available_platforms
|
||||
|
||||
providers = list(get_available_providers())
|
||||
|
|
|
@ -36,27 +36,42 @@ def run_upscale_correction(
|
|||
|
||||
if upscale.scale > 1:
|
||||
if "esrgan" in upscale.upscale_model:
|
||||
esrgan_stage = StageParams(
|
||||
esrgan_params = StageParams(
|
||||
tile_size=stage.tile_size, outscale=upscale.outscale
|
||||
)
|
||||
chain.append((upscale_resrgan, esrgan_stage, None))
|
||||
upscale_stage = (upscale_resrgan, esrgan_params, None)
|
||||
elif "stable-diffusion" in upscale.upscale_model:
|
||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||
chain.append((upscale_stable_diffusion, sd_stage, None))
|
||||
upscale_stage = (upscale_stable_diffusion, sd_stage, None)
|
||||
else:
|
||||
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
||||
upscale_stage = None
|
||||
|
||||
if upscale.faces:
|
||||
face_stage = StageParams(
|
||||
tile_size=stage.tile_size, outscale=upscale.face_outscale
|
||||
)
|
||||
if "codeformer" in upscale.correction_model:
|
||||
chain.append((correct_codeformer, face_stage, None))
|
||||
correct_stage = (correct_codeformer, face_stage, None)
|
||||
elif "gfpgan" in upscale.correction_model:
|
||||
chain.append((correct_gfpgan, face_stage, None))
|
||||
correct_stage = (correct_gfpgan, face_stage, None)
|
||||
else:
|
||||
logger.warn("unknown correction model: %s", upscale.correction_model)
|
||||
correct_stage = None
|
||||
|
||||
if upscale.upscale_order == "correction-both":
|
||||
chain.append(correct_stage)
|
||||
chain.append(upscale_stage)
|
||||
chain.append(correct_stage)
|
||||
elif upscale.upscale_order == "correction-first":
|
||||
chain.append(correct_stage)
|
||||
chain.append(upscale_stage)
|
||||
elif upscale.upscale_order == "correction-last":
|
||||
chain.append(upscale_stage)
|
||||
chain.append(correct_stage)
|
||||
else:
|
||||
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
|
||||
|
||||
return chain(
|
||||
job,
|
||||
|
|
|
@ -127,6 +127,14 @@
|
|||
"max": 512,
|
||||
"step": 8
|
||||
},
|
||||
"upscaleOrder": {
|
||||
"default": "correction-last",
|
||||
"keys": [
|
||||
"correction-both",
|
||||
"correction-first",
|
||||
"correction-last"
|
||||
]
|
||||
},
|
||||
"upscaling": {
|
||||
"default": "",
|
||||
"keys": []
|
||||
|
|
|
@ -114,6 +114,8 @@ export interface BrushParams {
|
|||
*/
|
||||
export interface UpscaleParams {
|
||||
enabled: boolean;
|
||||
upscaleOrder: string;
|
||||
|
||||
denoise: number;
|
||||
scale: number;
|
||||
outscale: number;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Checkbox, FormControlLabel, Stack } from '@mui/material';
|
||||
import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
|
||||
import { startCase } from 'lodash';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
import { useStore } from 'zustand';
|
||||
|
@ -106,5 +107,22 @@ export function UpscaleControl() {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
<FormControl>
|
||||
<InputLabel id={'upscale-order'}>Upscale Order</InputLabel>
|
||||
<Select
|
||||
labelId={'upscale-order'}
|
||||
label={'Upscale Order'}
|
||||
value={upscale.upscaleOrder}
|
||||
onChange={(e) => {
|
||||
setUpscale({
|
||||
upscaleOrder: e.target.value,
|
||||
});
|
||||
}}
|
||||
>
|
||||
{params.upscaleOrder.keys.map((name) =>
|
||||
<MenuItem key={name} value={name}>{startCase(name)}</MenuItem>)
|
||||
}
|
||||
</Select>
|
||||
</FormControl>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -408,6 +408,7 @@ export function createStateSlices(server: ServerParams) {
|
|||
faceStrength: server.faceStrength.default,
|
||||
outscale: server.outscale.default,
|
||||
scale: server.scale.default,
|
||||
upscaleOrder: server.upscaleOrder.default,
|
||||
},
|
||||
upscaleTab: {
|
||||
source: null,
|
||||
|
|
Loading…
Reference in New Issue