1
0
Fork 0

feat(gui): add highres parameters

This commit is contained in:
Sean Sube 2023-04-01 11:26:10 -05:00
parent f462d80cc4
commit ba09748e94
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
12 changed files with 199 additions and 34 deletions

View File

@ -16,6 +16,7 @@ from .diffusers.run import (
run_upscale_pipeline,
)
from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import run_upscale_correction
from .image import (
expand_image,
mask_filter_gaussian_multiply,
@ -48,7 +49,6 @@ from .server import (
apply_patch_facexlib,
apply_patches,
)
from .upscale import run_upscale_correction
from .utils import (
base_join,
get_and_clamp_float,

View File

@ -11,12 +11,12 @@ from onnx_web.chain.utils import process_tile_order
from ..chain import blend_mask, upscale_outpaint
from ..chain.base import ChainProgress
from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams, TileOrder, UpscaleParams
from ..params import Border, HighresParams, ImageParams, Size, StageParams, TileOrder, UpscaleParams
from ..server import ServerContext
from ..upscale import run_upscale_correction
from ..utils import run_gc
from ..worker import WorkerContext
from .load import get_latents_from_seed, load_pipeline
from .upscale import run_upscale_correction
from .utils import get_inversions_from_prompt, get_loras_from_prompt
logger = getLogger(__name__)
@ -29,13 +29,8 @@ def run_txt2img_pipeline(
size: Size,
outputs: List[str],
upscale: UpscaleParams,
highres: HighresParams,
) -> None:
# TODO: add to params
highres_scale = 4
highres_steps = 25
highres_strength = 0.2
highres_steps_post = int((params.steps - highres_steps) / highres_strength)
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
(prompt, loras) = get_loras_from_prompt(params.prompt)
@ -66,7 +61,7 @@ def run_txt2img_pipeline(
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=highres_steps,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)
@ -81,13 +76,13 @@ def run_txt2img_pipeline(
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=highres_steps,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)
for image, output in zip(result.images, outputs):
if highres_scale > 1:
if highres.scale > 1:
highres_progress = ChainProgress.from_progress(progress)
image = run_upscale_correction(
@ -115,7 +110,7 @@ def run_txt2img_pipeline(
def highres(tile: Image.Image, dims):
tile = tile.resize((size.height, size.width))
if params.lpw:
logger.debug("using LPW pipeline for img2img")
logger.debug("using LPW pipeline for highres")
rng = torch.manual_seed(params.seed)
result = highres_pipe.img2img(
tile,
@ -124,8 +119,8 @@ def run_txt2img_pipeline(
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres_steps_post,
strength=highres_strength,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=highres_progress,
)
@ -139,19 +134,19 @@ def run_txt2img_pipeline(
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres_steps_post,
strength=highres_strength,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=highres_progress,
)
return result.images[0]
logger.info("running highres fix for %s tiles", highres_scale)
logger.info("running highres fix for %s tiles", highres.scale)
image = process_tile_order(
TileOrder.grid,
image,
size.height // highres_scale,
highres_scale,
size.height // highres.scale,
highres.scale,
[highres],
)
@ -166,7 +161,7 @@ def run_txt2img_pipeline(
)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
save_params(server, output, params, size, upscale=upscale, highres=highres)
run_gc([job.get_device()])

View File

@ -3,16 +3,16 @@ from typing import Optional
from PIL import Image
from .chain import (
from ..chain import (
ChainPipeline,
correct_codeformer,
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
)
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
from .server import ServerContext
from .worker import ProgressCallback, WorkerContext
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)

View File

@ -8,7 +8,7 @@ from typing import Any, List, Optional
from PIL import Image
from .params import Border, ImageParams, Param, Size, UpscaleParams
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
from .server import ServerContext
from .utils import base_join
@ -36,6 +36,7 @@ def json_params(
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
) -> Any:
json = {
"outputs": outputs,
@ -49,6 +50,10 @@ def json_params(
json["border"] = border.tojson()
size = size.add_border(border)
if highres is not None:
json["highres"] = highres.tojson()
size = highres.resize(size)
if upscale is not None:
json["upscale"] = upscale.tojson()
size = upscale.resize(size)
@ -106,9 +111,10 @@ def save_params(
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
) -> str:
path = base_join(ctx.output_path, f"{output}.json")
json = json_params(output, params, size, upscale=upscale, border=border)
json = json_params(output, params, size, upscale=upscale, border=border, highres=highres)
with open(path, "w") as f:
f.write(dumps(json))
logger.debug("saved image params to: %s", path)

View File

@ -317,3 +317,25 @@ class UpscaleParams:
kwargs.get("tile_pad", self.tile_pad),
kwargs.get("upscale_order", self.upscale_order),
)
class HighresParams:
def __init__(
self,
scale: int,
steps: int,
strength: float,
):
self.scale = scale
self.steps = steps
self.strength = strength
def resize(self, size: Size) -> Size:
return Size(size.width * self.scale, size.height * self.scale)
def tojson(self):
return {
"scale": self.scale,
"steps": self.steps,
"strength": self.strength,
}

View File

@ -44,7 +44,7 @@ from .load import (
get_noise_sources,
get_upscaling_models,
)
from .params import border_from_request, pipeline_from_request, upscale_from_request
from .params import border_from_request, highres_from_request, pipeline_from_request, upscale_from_request
from .utils import wrap_route
logger = getLogger(__name__)
@ -174,6 +174,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(context)
upscale = upscale_from_request()
highres = highres_from_request()
output = make_output_name(context, "txt2img", params, size)
job_name = output[0]
@ -187,10 +188,11 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):
size,
output,
upscale,
highres,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
def inpaint(context: ServerContext, pool: DevicePoolExecutor):

View File

@ -5,7 +5,7 @@ import numpy as np
from flask import request
from ..diffusers.load import pipeline_schedulers
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
from ..params import Border, DeviceParams, ImageParams, HighresParams, Size, UpscaleParams
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
from .context import ServerContext
from .load import (
@ -169,3 +169,15 @@ def upscale_from_request() -> UpscaleParams:
scale=scale,
upscale_order=upscale_order,
)
def highres_from_request() -> HighresParams:
scale = get_and_clamp_int(request.args, "highresScale", 1, 4, 1)
steps = get_and_clamp_int(request.args, "highresSteps", 1, 4, 1)
strength = get_and_clamp_float(request.args, "highresStrength", 0.5, 1.0, 0.0)
return HighresParams(
scale,
steps,
strength,
)

View File

@ -144,6 +144,14 @@ export interface BlendParams {
mask: Blob;
}
export interface HighresParams {
enabled: boolean;
highresScale: number;
highresSteps: number;
highresStrength: number;
}
/**
* Output image data within the response.
*/
@ -201,6 +209,7 @@ export type RetryParams = {
model: ModelParams;
params: Txt2ImgParams;
upscale?: UpscaleParams;
highres?: HighresParams;
} | {
type: 'img2img';
model: ModelParams;
@ -274,7 +283,7 @@ export interface ApiClient {
/**
* Start a txt2img pipeline.
*/
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
/**
* Start an im2img pipeline.
@ -477,7 +486,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
},
};
},
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'txt2img', params);
appendModelToURL(url, model);
@ -493,6 +502,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}
if (doesExist(highres) && highres.enabled) {
url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER));
url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER));
url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT));
}
const image = await parseRequest(url, {
method: 'POST',
});

View File

@ -0,0 +1,73 @@
import { mustExist } from '@apextoaster/js-utils';
import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { ConfigContext, StateContext } from '../../state.js';
import { NumericField } from '../input/NumericField.js';
export function UpscaleControl() {
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
const highres = useStore(state, (s) => s.highres);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setHighres = useStore(state, (s) => s.setHighres);
const { t } = useTranslation();
return <Stack direction='row' spacing={4}>
<FormControlLabel
label={t('parameter.highres.label')}
control={<Checkbox
checked={highres.enabled}
value='check'
onChange={(event) => {
setHighres({
enabled: highres.enabled === false,
});
}}
/>}
/>
<NumericField
label={t('parameter.highres.steps')}
decimal
disabled={highres.enabled === false}
min={params.denoise.min}
max={params.denoise.max}
step={params.denoise.step}
value={highres.highresSteps}
onChange={(steps) => {
setHighres({
highresSteps: steps,
});
}}
/>
<NumericField
label={t('parameter.highres.scale')}
disabled={highres.enabled === false}
min={params.scale.min}
max={params.scale.max}
step={params.scale.step}
value={highres.highresScale}
onChange={(scale) => {
setHighres({
highresScale: scale,
});
}}
/>
<NumericField
label={t('parameter.highres.strength')}
disabled={highres.enabled === false}
min={params.strength.min}
max={params.strength.max}
step={params.outscale.step}
value={highres.highresStrength}
onChange={(strength) => {
setHighres({
highresStrength: strength,
});
}}
/>
</Stack>;
}

View File

@ -1,6 +1,6 @@
import { doesExist, Maybe } from '@apextoaster/js-utils';
import { merge } from 'lodash';
import { Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client/api.js';
import { HighresParams, Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client/api.js';
export interface ConfigNumber {
default: number;
@ -54,7 +54,8 @@ export type ServerParams = ConfigRanges<Required<
InpaintParams &
ModelParams &
OutpaintParams &
UpscaleParams
UpscaleParams &
HighresParams
>> & {
version: string;
};

View File

@ -55,6 +55,7 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
createOutpaintSlice,
createTxt2ImgSlice,
createUpscaleSlice,
createHighresSlice,
createBlendSlice,
createResetSlice,
} = createStateSlices(params);
@ -68,6 +69,7 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
...createTxt2ImgSlice(...slice),
...createOutpaintSlice(...slice),
...createUpscaleSlice(...slice),
...createHighresSlice(...slice),
...createBlendSlice(...slice),
...createResetSlice(...slice),
}), {

View File

@ -10,6 +10,7 @@ import {
BaseImgParams,
BlendParams,
BrushParams,
HighresParams,
ImageResponse,
Img2ImgParams,
InpaintParams,
@ -90,6 +91,13 @@ interface OutpaintSlice {
setOutpaint(pixels: Partial<OutpaintPixels>): void;
}
interface HighresSlice {
highres: HighresParams;
setHighres(params: Partial<HighresParams>): void;
resetHighres(): void;
}
interface UpscaleSlice {
upscale: UpscaleParams;
upscaleTab: TabState<UpscaleReqParams>;
@ -123,6 +131,7 @@ export type OnnxState
& ModelSlice
& OutpaintSlice
& Txt2ImgSlice
& HighresSlice
& UpscaleSlice
& BlendSlice
& ResetSlice;
@ -421,6 +430,33 @@ export function createStateSlices(server: ServerParams) {
},
});
const createHighresSlice: Slice<HighresSlice> = (set) => ({
highres: {
enabled: false,
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
},
setHighres(params) {
set((prev) => ({
highres: {
...prev.highres,
...params,
},
}));
},
resetHighres() {
set({
highres: {
enabled: false,
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
},
});
},
});
const createBlendSlice: Slice<BlendSlice> = (set) => ({
blend: {
mask: null,
@ -501,6 +537,7 @@ export function createStateSlices(server: ServerParams) {
createOutpaintSlice,
createTxt2ImgSlice,
createUpscaleSlice,
createHighresSlice,
createBlendSlice,
createResetSlice,
};