feat(gui): add highres parameters
This commit is contained in:
parent
f462d80cc4
commit
ba09748e94
|
@ -16,6 +16,7 @@ from .diffusers.run import (
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
)
|
)
|
||||||
from .diffusers.stub_scheduler import StubScheduler
|
from .diffusers.stub_scheduler import StubScheduler
|
||||||
|
from .diffusers.upscale import run_upscale_correction
|
||||||
from .image import (
|
from .image import (
|
||||||
expand_image,
|
expand_image,
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
|
@ -48,7 +49,6 @@ from .server import (
|
||||||
apply_patch_facexlib,
|
apply_patch_facexlib,
|
||||||
apply_patches,
|
apply_patches,
|
||||||
)
|
)
|
||||||
from .upscale import run_upscale_correction
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
base_join,
|
base_join,
|
||||||
get_and_clamp_float,
|
get_and_clamp_float,
|
||||||
|
|
|
@ -11,12 +11,12 @@ from onnx_web.chain.utils import process_tile_order
|
||||||
from ..chain import blend_mask, upscale_outpaint
|
from ..chain import blend_mask, upscale_outpaint
|
||||||
from ..chain.base import ChainProgress
|
from ..chain.base import ChainProgress
|
||||||
from ..output import save_image, save_params
|
from ..output import save_image, save_params
|
||||||
from ..params import Border, ImageParams, Size, StageParams, TileOrder, UpscaleParams
|
from ..params import Border, HighresParams, ImageParams, Size, StageParams, TileOrder, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..upscale import run_upscale_correction
|
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .load import get_latents_from_seed, load_pipeline
|
from .load import get_latents_from_seed, load_pipeline
|
||||||
|
from .upscale import run_upscale_correction
|
||||||
from .utils import get_inversions_from_prompt, get_loras_from_prompt
|
from .utils import get_inversions_from_prompt, get_loras_from_prompt
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -29,13 +29,8 @@ def run_txt2img_pipeline(
|
||||||
size: Size,
|
size: Size,
|
||||||
outputs: List[str],
|
outputs: List[str],
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
|
highres: HighresParams,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: add to params
|
|
||||||
highres_scale = 4
|
|
||||||
highres_steps = 25
|
|
||||||
highres_strength = 0.2
|
|
||||||
highres_steps_post = int((params.steps - highres_steps) / highres_strength)
|
|
||||||
|
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
|
|
||||||
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
||||||
|
@ -66,7 +61,7 @@ def run_txt2img_pipeline(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_images_per_prompt=params.batch,
|
num_images_per_prompt=params.batch,
|
||||||
num_inference_steps=highres_steps,
|
num_inference_steps=params.steps,
|
||||||
eta=params.eta,
|
eta=params.eta,
|
||||||
callback=progress,
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
@ -81,13 +76,13 @@ def run_txt2img_pipeline(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_images_per_prompt=params.batch,
|
num_images_per_prompt=params.batch,
|
||||||
num_inference_steps=highres_steps,
|
num_inference_steps=params.steps,
|
||||||
eta=params.eta,
|
eta=params.eta,
|
||||||
callback=progress,
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
for image, output in zip(result.images, outputs):
|
for image, output in zip(result.images, outputs):
|
||||||
if highres_scale > 1:
|
if highres.scale > 1:
|
||||||
highres_progress = ChainProgress.from_progress(progress)
|
highres_progress = ChainProgress.from_progress(progress)
|
||||||
|
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
|
@ -115,7 +110,7 @@ def run_txt2img_pipeline(
|
||||||
def highres(tile: Image.Image, dims):
|
def highres(tile: Image.Image, dims):
|
||||||
tile = tile.resize((size.height, size.width))
|
tile = tile.resize((size.height, size.width))
|
||||||
if params.lpw:
|
if params.lpw:
|
||||||
logger.debug("using LPW pipeline for img2img")
|
logger.debug("using LPW pipeline for highres")
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
result = highres_pipe.img2img(
|
result = highres_pipe.img2img(
|
||||||
tile,
|
tile,
|
||||||
|
@ -124,8 +119,8 @@ def run_txt2img_pipeline(
|
||||||
guidance_scale=params.cfg,
|
guidance_scale=params.cfg,
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_images_per_prompt=1,
|
num_images_per_prompt=1,
|
||||||
num_inference_steps=highres_steps_post,
|
num_inference_steps=highres.steps,
|
||||||
strength=highres_strength,
|
strength=highres.strength,
|
||||||
eta=params.eta,
|
eta=params.eta,
|
||||||
callback=highres_progress,
|
callback=highres_progress,
|
||||||
)
|
)
|
||||||
|
@ -139,19 +134,19 @@ def run_txt2img_pipeline(
|
||||||
guidance_scale=params.cfg,
|
guidance_scale=params.cfg,
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_images_per_prompt=1,
|
num_images_per_prompt=1,
|
||||||
num_inference_steps=highres_steps_post,
|
num_inference_steps=highres.steps,
|
||||||
strength=highres_strength,
|
strength=highres.strength,
|
||||||
eta=params.eta,
|
eta=params.eta,
|
||||||
callback=highres_progress,
|
callback=highres_progress,
|
||||||
)
|
)
|
||||||
return result.images[0]
|
return result.images[0]
|
||||||
|
|
||||||
logger.info("running highres fix for %s tiles", highres_scale)
|
logger.info("running highres fix for %s tiles", highres.scale)
|
||||||
image = process_tile_order(
|
image = process_tile_order(
|
||||||
TileOrder.grid,
|
TileOrder.grid,
|
||||||
image,
|
image,
|
||||||
size.height // highres_scale,
|
size.height // highres.scale,
|
||||||
highres_scale,
|
highres.scale,
|
||||||
[highres],
|
[highres],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -166,7 +161,7 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
save_params(server, output, params, size, upscale=upscale)
|
save_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||||
|
|
||||||
run_gc([job.get_device()])
|
run_gc([job.get_device()])
|
||||||
|
|
||||||
|
|
|
@ -3,16 +3,16 @@ from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from .chain import (
|
from ..chain import (
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
correct_codeformer,
|
correct_codeformer,
|
||||||
correct_gfpgan,
|
correct_gfpgan,
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
upscale_stable_diffusion,
|
upscale_stable_diffusion,
|
||||||
)
|
)
|
||||||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||||
from .server import ServerContext
|
from ..server import ServerContext
|
||||||
from .worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any, List, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from .params import Border, ImageParams, Param, Size, UpscaleParams
|
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
|
||||||
from .server import ServerContext
|
from .server import ServerContext
|
||||||
from .utils import base_join
|
from .utils import base_join
|
||||||
|
|
||||||
|
@ -36,6 +36,7 @@ def json_params(
|
||||||
size: Size,
|
size: Size,
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
|
highres: Optional[HighresParams] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
json = {
|
json = {
|
||||||
"outputs": outputs,
|
"outputs": outputs,
|
||||||
|
@ -49,6 +50,10 @@ def json_params(
|
||||||
json["border"] = border.tojson()
|
json["border"] = border.tojson()
|
||||||
size = size.add_border(border)
|
size = size.add_border(border)
|
||||||
|
|
||||||
|
if highres is not None:
|
||||||
|
json["highres"] = highres.tojson()
|
||||||
|
size = highres.resize(size)
|
||||||
|
|
||||||
if upscale is not None:
|
if upscale is not None:
|
||||||
json["upscale"] = upscale.tojson()
|
json["upscale"] = upscale.tojson()
|
||||||
size = upscale.resize(size)
|
size = upscale.resize(size)
|
||||||
|
@ -106,9 +111,10 @@ def save_params(
|
||||||
size: Size,
|
size: Size,
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
|
highres: Optional[HighresParams] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
path = base_join(ctx.output_path, f"{output}.json")
|
path = base_join(ctx.output_path, f"{output}.json")
|
||||||
json = json_params(output, params, size, upscale=upscale, border=border)
|
json = json_params(output, params, size, upscale=upscale, border=border, highres=highres)
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
f.write(dumps(json))
|
f.write(dumps(json))
|
||||||
logger.debug("saved image params to: %s", path)
|
logger.debug("saved image params to: %s", path)
|
||||||
|
|
|
@ -317,3 +317,25 @@ class UpscaleParams:
|
||||||
kwargs.get("tile_pad", self.tile_pad),
|
kwargs.get("tile_pad", self.tile_pad),
|
||||||
kwargs.get("upscale_order", self.upscale_order),
|
kwargs.get("upscale_order", self.upscale_order),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HighresParams:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scale: int,
|
||||||
|
steps: int,
|
||||||
|
strength: float,
|
||||||
|
):
|
||||||
|
self.scale = scale
|
||||||
|
self.steps = steps
|
||||||
|
self.strength = strength
|
||||||
|
|
||||||
|
def resize(self, size: Size) -> Size:
|
||||||
|
return Size(size.width * self.scale, size.height * self.scale)
|
||||||
|
|
||||||
|
def tojson(self):
|
||||||
|
return {
|
||||||
|
"scale": self.scale,
|
||||||
|
"steps": self.steps,
|
||||||
|
"strength": self.strength,
|
||||||
|
}
|
||||||
|
|
|
@ -44,7 +44,7 @@ from .load import (
|
||||||
get_noise_sources,
|
get_noise_sources,
|
||||||
get_upscaling_models,
|
get_upscaling_models,
|
||||||
)
|
)
|
||||||
from .params import border_from_request, pipeline_from_request, upscale_from_request
|
from .params import border_from_request, highres_from_request, pipeline_from_request, upscale_from_request
|
||||||
from .utils import wrap_route
|
from .utils import wrap_route
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -174,6 +174,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
device, params, size = pipeline_from_request(context)
|
device, params, size = pipeline_from_request(context)
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
|
highres = highres_from_request()
|
||||||
|
|
||||||
output = make_output_name(context, "txt2img", params, size)
|
output = make_output_name(context, "txt2img", params, size)
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
|
@ -187,10 +188,11 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
size,
|
size,
|
||||||
output,
|
output,
|
||||||
upscale,
|
upscale,
|
||||||
|
highres,
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||||
|
|
||||||
|
|
||||||
def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
def inpaint(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
||||||
from flask import request
|
from flask import request
|
||||||
|
|
||||||
from ..diffusers.load import pipeline_schedulers
|
from ..diffusers.load import pipeline_schedulers
|
||||||
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
|
from ..params import Border, DeviceParams, ImageParams, HighresParams, Size, UpscaleParams
|
||||||
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
|
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
from .load import (
|
from .load import (
|
||||||
|
@ -169,3 +169,15 @@ def upscale_from_request() -> UpscaleParams:
|
||||||
scale=scale,
|
scale=scale,
|
||||||
upscale_order=upscale_order,
|
upscale_order=upscale_order,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def highres_from_request() -> HighresParams:
|
||||||
|
scale = get_and_clamp_int(request.args, "highresScale", 1, 4, 1)
|
||||||
|
steps = get_and_clamp_int(request.args, "highresSteps", 1, 4, 1)
|
||||||
|
strength = get_and_clamp_float(request.args, "highresStrength", 0.5, 1.0, 0.0)
|
||||||
|
|
||||||
|
return HighresParams(
|
||||||
|
scale,
|
||||||
|
steps,
|
||||||
|
strength,
|
||||||
|
)
|
||||||
|
|
|
@ -144,6 +144,14 @@ export interface BlendParams {
|
||||||
mask: Blob;
|
mask: Blob;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface HighresParams {
|
||||||
|
enabled: boolean;
|
||||||
|
|
||||||
|
highresScale: number;
|
||||||
|
highresSteps: number;
|
||||||
|
highresStrength: number;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Output image data within the response.
|
* Output image data within the response.
|
||||||
*/
|
*/
|
||||||
|
@ -201,6 +209,7 @@ export type RetryParams = {
|
||||||
model: ModelParams;
|
model: ModelParams;
|
||||||
params: Txt2ImgParams;
|
params: Txt2ImgParams;
|
||||||
upscale?: UpscaleParams;
|
upscale?: UpscaleParams;
|
||||||
|
highres?: HighresParams;
|
||||||
} | {
|
} | {
|
||||||
type: 'img2img';
|
type: 'img2img';
|
||||||
model: ModelParams;
|
model: ModelParams;
|
||||||
|
@ -274,7 +283,7 @@ export interface ApiClient {
|
||||||
/**
|
/**
|
||||||
* Start a txt2img pipeline.
|
* Start a txt2img pipeline.
|
||||||
*/
|
*/
|
||||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
|
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start an im2img pipeline.
|
* Start an im2img pipeline.
|
||||||
|
@ -477,7 +486,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
|
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
||||||
const url = makeImageURL(root, 'txt2img', params);
|
const url = makeImageURL(root, 'txt2img', params);
|
||||||
appendModelToURL(url, model);
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
|
@ -493,6 +502,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
appendUpscaleToURL(url, upscale);
|
appendUpscaleToURL(url, upscale);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (doesExist(highres) && highres.enabled) {
|
||||||
|
url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER));
|
||||||
|
url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER));
|
||||||
|
url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT));
|
||||||
|
}
|
||||||
|
|
||||||
const image = await parseRequest(url, {
|
const image = await parseRequest(url, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
import { mustExist } from '@apextoaster/js-utils';
|
||||||
|
import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
|
||||||
|
import * as React from 'react';
|
||||||
|
import { useContext } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
|
import { ConfigContext, StateContext } from '../../state.js';
|
||||||
|
import { NumericField } from '../input/NumericField.js';
|
||||||
|
|
||||||
|
export function UpscaleControl() {
|
||||||
|
const { params } = mustExist(useContext(ConfigContext));
|
||||||
|
const state = mustExist(useContext(StateContext));
|
||||||
|
const highres = useStore(state, (s) => s.highres);
|
||||||
|
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||||
|
const setHighres = useStore(state, (s) => s.setHighres);
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
return <Stack direction='row' spacing={4}>
|
||||||
|
<FormControlLabel
|
||||||
|
label={t('parameter.highres.label')}
|
||||||
|
control={<Checkbox
|
||||||
|
checked={highres.enabled}
|
||||||
|
value='check'
|
||||||
|
onChange={(event) => {
|
||||||
|
setHighres({
|
||||||
|
enabled: highres.enabled === false,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>}
|
||||||
|
/>
|
||||||
|
<NumericField
|
||||||
|
label={t('parameter.highres.steps')}
|
||||||
|
decimal
|
||||||
|
disabled={highres.enabled === false}
|
||||||
|
min={params.denoise.min}
|
||||||
|
max={params.denoise.max}
|
||||||
|
step={params.denoise.step}
|
||||||
|
value={highres.highresSteps}
|
||||||
|
onChange={(steps) => {
|
||||||
|
setHighres({
|
||||||
|
highresSteps: steps,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<NumericField
|
||||||
|
label={t('parameter.highres.scale')}
|
||||||
|
disabled={highres.enabled === false}
|
||||||
|
min={params.scale.min}
|
||||||
|
max={params.scale.max}
|
||||||
|
step={params.scale.step}
|
||||||
|
value={highres.highresScale}
|
||||||
|
onChange={(scale) => {
|
||||||
|
setHighres({
|
||||||
|
highresScale: scale,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<NumericField
|
||||||
|
label={t('parameter.highres.strength')}
|
||||||
|
disabled={highres.enabled === false}
|
||||||
|
min={params.strength.min}
|
||||||
|
max={params.strength.max}
|
||||||
|
step={params.outscale.step}
|
||||||
|
value={highres.highresStrength}
|
||||||
|
onChange={(strength) => {
|
||||||
|
setHighres({
|
||||||
|
highresStrength: strength,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Stack>;
|
||||||
|
}
|
|
@ -1,6 +1,6 @@
|
||||||
import { doesExist, Maybe } from '@apextoaster/js-utils';
|
import { doesExist, Maybe } from '@apextoaster/js-utils';
|
||||||
import { merge } from 'lodash';
|
import { merge } from 'lodash';
|
||||||
import { Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client/api.js';
|
import { HighresParams, Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client/api.js';
|
||||||
|
|
||||||
export interface ConfigNumber {
|
export interface ConfigNumber {
|
||||||
default: number;
|
default: number;
|
||||||
|
@ -54,7 +54,8 @@ export type ServerParams = ConfigRanges<Required<
|
||||||
InpaintParams &
|
InpaintParams &
|
||||||
ModelParams &
|
ModelParams &
|
||||||
OutpaintParams &
|
OutpaintParams &
|
||||||
UpscaleParams
|
UpscaleParams &
|
||||||
|
HighresParams
|
||||||
>> & {
|
>> & {
|
||||||
version: string;
|
version: string;
|
||||||
};
|
};
|
||||||
|
|
|
@ -55,6 +55,7 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
|
||||||
createOutpaintSlice,
|
createOutpaintSlice,
|
||||||
createTxt2ImgSlice,
|
createTxt2ImgSlice,
|
||||||
createUpscaleSlice,
|
createUpscaleSlice,
|
||||||
|
createHighresSlice,
|
||||||
createBlendSlice,
|
createBlendSlice,
|
||||||
createResetSlice,
|
createResetSlice,
|
||||||
} = createStateSlices(params);
|
} = createStateSlices(params);
|
||||||
|
@ -68,6 +69,7 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
|
||||||
...createTxt2ImgSlice(...slice),
|
...createTxt2ImgSlice(...slice),
|
||||||
...createOutpaintSlice(...slice),
|
...createOutpaintSlice(...slice),
|
||||||
...createUpscaleSlice(...slice),
|
...createUpscaleSlice(...slice),
|
||||||
|
...createHighresSlice(...slice),
|
||||||
...createBlendSlice(...slice),
|
...createBlendSlice(...slice),
|
||||||
...createResetSlice(...slice),
|
...createResetSlice(...slice),
|
||||||
}), {
|
}), {
|
||||||
|
|
|
@ -10,6 +10,7 @@ import {
|
||||||
BaseImgParams,
|
BaseImgParams,
|
||||||
BlendParams,
|
BlendParams,
|
||||||
BrushParams,
|
BrushParams,
|
||||||
|
HighresParams,
|
||||||
ImageResponse,
|
ImageResponse,
|
||||||
Img2ImgParams,
|
Img2ImgParams,
|
||||||
InpaintParams,
|
InpaintParams,
|
||||||
|
@ -90,6 +91,13 @@ interface OutpaintSlice {
|
||||||
setOutpaint(pixels: Partial<OutpaintPixels>): void;
|
setOutpaint(pixels: Partial<OutpaintPixels>): void;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface HighresSlice {
|
||||||
|
highres: HighresParams;
|
||||||
|
|
||||||
|
setHighres(params: Partial<HighresParams>): void;
|
||||||
|
resetHighres(): void;
|
||||||
|
}
|
||||||
|
|
||||||
interface UpscaleSlice {
|
interface UpscaleSlice {
|
||||||
upscale: UpscaleParams;
|
upscale: UpscaleParams;
|
||||||
upscaleTab: TabState<UpscaleReqParams>;
|
upscaleTab: TabState<UpscaleReqParams>;
|
||||||
|
@ -123,6 +131,7 @@ export type OnnxState
|
||||||
& ModelSlice
|
& ModelSlice
|
||||||
& OutpaintSlice
|
& OutpaintSlice
|
||||||
& Txt2ImgSlice
|
& Txt2ImgSlice
|
||||||
|
& HighresSlice
|
||||||
& UpscaleSlice
|
& UpscaleSlice
|
||||||
& BlendSlice
|
& BlendSlice
|
||||||
& ResetSlice;
|
& ResetSlice;
|
||||||
|
@ -421,6 +430,33 @@ export function createStateSlices(server: ServerParams) {
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const createHighresSlice: Slice<HighresSlice> = (set) => ({
|
||||||
|
highres: {
|
||||||
|
enabled: false,
|
||||||
|
highresSteps: server.highresSteps.default,
|
||||||
|
highresScale: server.highresScale.default,
|
||||||
|
highresStrength: server.highresStrength.default,
|
||||||
|
},
|
||||||
|
setHighres(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
highres: {
|
||||||
|
...prev.highres,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
resetHighres() {
|
||||||
|
set({
|
||||||
|
highres: {
|
||||||
|
enabled: false,
|
||||||
|
highresSteps: server.highresSteps.default,
|
||||||
|
highresScale: server.highresScale.default,
|
||||||
|
highresStrength: server.highresStrength.default,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
const createBlendSlice: Slice<BlendSlice> = (set) => ({
|
const createBlendSlice: Slice<BlendSlice> = (set) => ({
|
||||||
blend: {
|
blend: {
|
||||||
mask: null,
|
mask: null,
|
||||||
|
@ -501,6 +537,7 @@ export function createStateSlices(server: ServerParams) {
|
||||||
createOutpaintSlice,
|
createOutpaintSlice,
|
||||||
createTxt2ImgSlice,
|
createTxt2ImgSlice,
|
||||||
createUpscaleSlice,
|
createUpscaleSlice,
|
||||||
|
createHighresSlice,
|
||||||
createBlendSlice,
|
createBlendSlice,
|
||||||
createResetSlice,
|
createResetSlice,
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue