feat(gui): add highres parameters
This commit is contained in:
parent
f462d80cc4
commit
ba09748e94
|
@ -16,6 +16,7 @@ from .diffusers.run import (
|
|||
run_upscale_pipeline,
|
||||
)
|
||||
from .diffusers.stub_scheduler import StubScheduler
|
||||
from .diffusers.upscale import run_upscale_correction
|
||||
from .image import (
|
||||
expand_image,
|
||||
mask_filter_gaussian_multiply,
|
||||
|
@ -48,7 +49,6 @@ from .server import (
|
|||
apply_patch_facexlib,
|
||||
apply_patches,
|
||||
)
|
||||
from .upscale import run_upscale_correction
|
||||
from .utils import (
|
||||
base_join,
|
||||
get_and_clamp_float,
|
||||
|
|
|
@ -11,12 +11,12 @@ from onnx_web.chain.utils import process_tile_order
|
|||
from ..chain import blend_mask, upscale_outpaint
|
||||
from ..chain.base import ChainProgress
|
||||
from ..output import save_image, save_params
|
||||
from ..params import Border, ImageParams, Size, StageParams, TileOrder, UpscaleParams
|
||||
from ..params import Border, HighresParams, ImageParams, Size, StageParams, TileOrder, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..upscale import run_upscale_correction
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from .load import get_latents_from_seed, load_pipeline
|
||||
from .upscale import run_upscale_correction
|
||||
from .utils import get_inversions_from_prompt, get_loras_from_prompt
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -29,13 +29,8 @@ def run_txt2img_pipeline(
|
|||
size: Size,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
highres: HighresParams,
|
||||
) -> None:
|
||||
# TODO: add to params
|
||||
highres_scale = 4
|
||||
highres_steps = 25
|
||||
highres_strength = 0.2
|
||||
highres_steps_post = int((params.steps - highres_steps) / highres_strength)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
|
||||
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
||||
|
@ -66,7 +61,7 @@ def run_txt2img_pipeline(
|
|||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_images_per_prompt=params.batch,
|
||||
num_inference_steps=highres_steps,
|
||||
num_inference_steps=params.steps,
|
||||
eta=params.eta,
|
||||
callback=progress,
|
||||
)
|
||||
|
@ -81,13 +76,13 @@ def run_txt2img_pipeline(
|
|||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_images_per_prompt=params.batch,
|
||||
num_inference_steps=highres_steps,
|
||||
num_inference_steps=params.steps,
|
||||
eta=params.eta,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
for image, output in zip(result.images, outputs):
|
||||
if highres_scale > 1:
|
||||
if highres.scale > 1:
|
||||
highres_progress = ChainProgress.from_progress(progress)
|
||||
|
||||
image = run_upscale_correction(
|
||||
|
@ -115,7 +110,7 @@ def run_txt2img_pipeline(
|
|||
def highres(tile: Image.Image, dims):
|
||||
tile = tile.resize((size.height, size.width))
|
||||
if params.lpw:
|
||||
logger.debug("using LPW pipeline for img2img")
|
||||
logger.debug("using LPW pipeline for highres")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = highres_pipe.img2img(
|
||||
tile,
|
||||
|
@ -124,8 +119,8 @@ def run_txt2img_pipeline(
|
|||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_images_per_prompt=1,
|
||||
num_inference_steps=highres_steps_post,
|
||||
strength=highres_strength,
|
||||
num_inference_steps=highres.steps,
|
||||
strength=highres.strength,
|
||||
eta=params.eta,
|
||||
callback=highres_progress,
|
||||
)
|
||||
|
@ -139,19 +134,19 @@ def run_txt2img_pipeline(
|
|||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_images_per_prompt=1,
|
||||
num_inference_steps=highres_steps_post,
|
||||
strength=highres_strength,
|
||||
num_inference_steps=highres.steps,
|
||||
strength=highres.strength,
|
||||
eta=params.eta,
|
||||
callback=highres_progress,
|
||||
)
|
||||
return result.images[0]
|
||||
|
||||
logger.info("running highres fix for %s tiles", highres_scale)
|
||||
logger.info("running highres fix for %s tiles", highres.scale)
|
||||
image = process_tile_order(
|
||||
TileOrder.grid,
|
||||
image,
|
||||
size.height // highres_scale,
|
||||
highres_scale,
|
||||
size.height // highres.scale,
|
||||
highres.scale,
|
||||
[highres],
|
||||
)
|
||||
|
||||
|
@ -166,7 +161,7 @@ def run_txt2img_pipeline(
|
|||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
save_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||
|
||||
run_gc([job.get_device()])
|
||||
|
||||
|
|
|
@ -3,16 +3,16 @@ from typing import Optional
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from .chain import (
|
||||
from ..chain import (
|
||||
ChainPipeline,
|
||||
correct_codeformer,
|
||||
correct_gfpgan,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
)
|
||||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from .server import ServerContext
|
||||
from .worker import ProgressCallback, WorkerContext
|
||||
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
@ -8,7 +8,7 @@ from typing import Any, List, Optional
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from .params import Border, ImageParams, Param, Size, UpscaleParams
|
||||
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
|
||||
from .server import ServerContext
|
||||
from .utils import base_join
|
||||
|
||||
|
@ -36,6 +36,7 @@ def json_params(
|
|||
size: Size,
|
||||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
) -> Any:
|
||||
json = {
|
||||
"outputs": outputs,
|
||||
|
@ -49,6 +50,10 @@ def json_params(
|
|||
json["border"] = border.tojson()
|
||||
size = size.add_border(border)
|
||||
|
||||
if highres is not None:
|
||||
json["highres"] = highres.tojson()
|
||||
size = highres.resize(size)
|
||||
|
||||
if upscale is not None:
|
||||
json["upscale"] = upscale.tojson()
|
||||
size = upscale.resize(size)
|
||||
|
@ -106,9 +111,10 @@ def save_params(
|
|||
size: Size,
|
||||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
) -> str:
|
||||
path = base_join(ctx.output_path, f"{output}.json")
|
||||
json = json_params(output, params, size, upscale=upscale, border=border)
|
||||
json = json_params(output, params, size, upscale=upscale, border=border, highres=highres)
|
||||
with open(path, "w") as f:
|
||||
f.write(dumps(json))
|
||||
logger.debug("saved image params to: %s", path)
|
||||
|
|
|
@ -317,3 +317,25 @@ class UpscaleParams:
|
|||
kwargs.get("tile_pad", self.tile_pad),
|
||||
kwargs.get("upscale_order", self.upscale_order),
|
||||
)
|
||||
|
||||
|
||||
class HighresParams:
|
||||
def __init__(
|
||||
self,
|
||||
scale: int,
|
||||
steps: int,
|
||||
strength: float,
|
||||
):
|
||||
self.scale = scale
|
||||
self.steps = steps
|
||||
self.strength = strength
|
||||
|
||||
def resize(self, size: Size) -> Size:
|
||||
return Size(size.width * self.scale, size.height * self.scale)
|
||||
|
||||
def tojson(self):
|
||||
return {
|
||||
"scale": self.scale,
|
||||
"steps": self.steps,
|
||||
"strength": self.strength,
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ from .load import (
|
|||
get_noise_sources,
|
||||
get_upscaling_models,
|
||||
)
|
||||
from .params import border_from_request, pipeline_from_request, upscale_from_request
|
||||
from .params import border_from_request, highres_from_request, pipeline_from_request, upscale_from_request
|
||||
from .utils import wrap_route
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -174,6 +174,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
|||
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||
device, params, size = pipeline_from_request(context)
|
||||
upscale = upscale_from_request()
|
||||
highres = highres_from_request()
|
||||
|
||||
output = make_output_name(context, "txt2img", params, size)
|
||||
job_name = output[0]
|
||||
|
@ -187,10 +188,11 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
|||
size,
|
||||
output,
|
||||
upscale,
|
||||
highres,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||
|
||||
|
||||
def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
from flask import request
|
||||
|
||||
from ..diffusers.load import pipeline_schedulers
|
||||
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
|
||||
from ..params import Border, DeviceParams, ImageParams, HighresParams, Size, UpscaleParams
|
||||
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
|
||||
from .context import ServerContext
|
||||
from .load import (
|
||||
|
@ -169,3 +169,15 @@ def upscale_from_request() -> UpscaleParams:
|
|||
scale=scale,
|
||||
upscale_order=upscale_order,
|
||||
)
|
||||
|
||||
|
||||
def highres_from_request() -> HighresParams:
|
||||
scale = get_and_clamp_int(request.args, "highresScale", 1, 4, 1)
|
||||
steps = get_and_clamp_int(request.args, "highresSteps", 1, 4, 1)
|
||||
strength = get_and_clamp_float(request.args, "highresStrength", 0.5, 1.0, 0.0)
|
||||
|
||||
return HighresParams(
|
||||
scale,
|
||||
steps,
|
||||
strength,
|
||||
)
|
||||
|
|
|
@ -144,6 +144,14 @@ export interface BlendParams {
|
|||
mask: Blob;
|
||||
}
|
||||
|
||||
export interface HighresParams {
|
||||
enabled: boolean;
|
||||
|
||||
highresScale: number;
|
||||
highresSteps: number;
|
||||
highresStrength: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Output image data within the response.
|
||||
*/
|
||||
|
@ -201,6 +209,7 @@ export type RetryParams = {
|
|||
model: ModelParams;
|
||||
params: Txt2ImgParams;
|
||||
upscale?: UpscaleParams;
|
||||
highres?: HighresParams;
|
||||
} | {
|
||||
type: 'img2img';
|
||||
model: ModelParams;
|
||||
|
@ -274,7 +283,7 @@ export interface ApiClient {
|
|||
/**
|
||||
* Start a txt2img pipeline.
|
||||
*/
|
||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
|
||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start an im2img pipeline.
|
||||
|
@ -477,7 +486,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
},
|
||||
};
|
||||
},
|
||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
|
||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
||||
const url = makeImageURL(root, 'txt2img', params);
|
||||
appendModelToURL(url, model);
|
||||
|
||||
|
@ -493,6 +502,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
appendUpscaleToURL(url, upscale);
|
||||
}
|
||||
|
||||
if (doesExist(highres) && highres.enabled) {
|
||||
url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER));
|
||||
url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER));
|
||||
url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT));
|
||||
}
|
||||
|
||||
const image = await parseRequest(url, {
|
||||
method: 'POST',
|
||||
});
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigContext, StateContext } from '../../state.js';
|
||||
import { NumericField } from '../input/NumericField.js';
|
||||
|
||||
export function UpscaleControl() {
|
||||
const { params } = mustExist(useContext(ConfigContext));
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const highres = useStore(state, (s) => s.highres);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setHighres = useStore(state, (s) => s.setHighres);
|
||||
const { t } = useTranslation();
|
||||
|
||||
return <Stack direction='row' spacing={4}>
|
||||
<FormControlLabel
|
||||
label={t('parameter.highres.label')}
|
||||
control={<Checkbox
|
||||
checked={highres.enabled}
|
||||
value='check'
|
||||
onChange={(event) => {
|
||||
setHighres({
|
||||
enabled: highres.enabled === false,
|
||||
});
|
||||
}}
|
||||
/>}
|
||||
/>
|
||||
<NumericField
|
||||
label={t('parameter.highres.steps')}
|
||||
decimal
|
||||
disabled={highres.enabled === false}
|
||||
min={params.denoise.min}
|
||||
max={params.denoise.max}
|
||||
step={params.denoise.step}
|
||||
value={highres.highresSteps}
|
||||
onChange={(steps) => {
|
||||
setHighres({
|
||||
highresSteps: steps,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
label={t('parameter.highres.scale')}
|
||||
disabled={highres.enabled === false}
|
||||
min={params.scale.min}
|
||||
max={params.scale.max}
|
||||
step={params.scale.step}
|
||||
value={highres.highresScale}
|
||||
onChange={(scale) => {
|
||||
setHighres({
|
||||
highresScale: scale,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
label={t('parameter.highres.strength')}
|
||||
disabled={highres.enabled === false}
|
||||
min={params.strength.min}
|
||||
max={params.strength.max}
|
||||
step={params.outscale.step}
|
||||
value={highres.highresStrength}
|
||||
onChange={(strength) => {
|
||||
setHighres({
|
||||
highresStrength: strength,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</Stack>;
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
import { doesExist, Maybe } from '@apextoaster/js-utils';
|
||||
import { merge } from 'lodash';
|
||||
import { Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client/api.js';
|
||||
import { HighresParams, Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client/api.js';
|
||||
|
||||
export interface ConfigNumber {
|
||||
default: number;
|
||||
|
@ -54,7 +54,8 @@ export type ServerParams = ConfigRanges<Required<
|
|||
InpaintParams &
|
||||
ModelParams &
|
||||
OutpaintParams &
|
||||
UpscaleParams
|
||||
UpscaleParams &
|
||||
HighresParams
|
||||
>> & {
|
||||
version: string;
|
||||
};
|
||||
|
|
|
@ -55,6 +55,7 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
|
|||
createOutpaintSlice,
|
||||
createTxt2ImgSlice,
|
||||
createUpscaleSlice,
|
||||
createHighresSlice,
|
||||
createBlendSlice,
|
||||
createResetSlice,
|
||||
} = createStateSlices(params);
|
||||
|
@ -68,6 +69,7 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
|
|||
...createTxt2ImgSlice(...slice),
|
||||
...createOutpaintSlice(...slice),
|
||||
...createUpscaleSlice(...slice),
|
||||
...createHighresSlice(...slice),
|
||||
...createBlendSlice(...slice),
|
||||
...createResetSlice(...slice),
|
||||
}), {
|
||||
|
|
|
@ -10,6 +10,7 @@ import {
|
|||
BaseImgParams,
|
||||
BlendParams,
|
||||
BrushParams,
|
||||
HighresParams,
|
||||
ImageResponse,
|
||||
Img2ImgParams,
|
||||
InpaintParams,
|
||||
|
@ -90,6 +91,13 @@ interface OutpaintSlice {
|
|||
setOutpaint(pixels: Partial<OutpaintPixels>): void;
|
||||
}
|
||||
|
||||
interface HighresSlice {
|
||||
highres: HighresParams;
|
||||
|
||||
setHighres(params: Partial<HighresParams>): void;
|
||||
resetHighres(): void;
|
||||
}
|
||||
|
||||
interface UpscaleSlice {
|
||||
upscale: UpscaleParams;
|
||||
upscaleTab: TabState<UpscaleReqParams>;
|
||||
|
@ -123,6 +131,7 @@ export type OnnxState
|
|||
& ModelSlice
|
||||
& OutpaintSlice
|
||||
& Txt2ImgSlice
|
||||
& HighresSlice
|
||||
& UpscaleSlice
|
||||
& BlendSlice
|
||||
& ResetSlice;
|
||||
|
@ -421,6 +430,33 @@ export function createStateSlices(server: ServerParams) {
|
|||
},
|
||||
});
|
||||
|
||||
const createHighresSlice: Slice<HighresSlice> = (set) => ({
|
||||
highres: {
|
||||
enabled: false,
|
||||
highresSteps: server.highresSteps.default,
|
||||
highresScale: server.highresScale.default,
|
||||
highresStrength: server.highresStrength.default,
|
||||
},
|
||||
setHighres(params) {
|
||||
set((prev) => ({
|
||||
highres: {
|
||||
...prev.highres,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
resetHighres() {
|
||||
set({
|
||||
highres: {
|
||||
enabled: false,
|
||||
highresSteps: server.highresSteps.default,
|
||||
highresScale: server.highresScale.default,
|
||||
highresStrength: server.highresStrength.default,
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const createBlendSlice: Slice<BlendSlice> = (set) => ({
|
||||
blend: {
|
||||
mask: null,
|
||||
|
@ -501,6 +537,7 @@ export function createStateSlices(server: ServerParams) {
|
|||
createOutpaintSlice,
|
||||
createTxt2ImgSlice,
|
||||
createUpscaleSlice,
|
||||
createHighresSlice,
|
||||
createBlendSlice,
|
||||
createResetSlice,
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue