1
0
Fork 0

feat(gui): add highres control to most tabs

This commit is contained in:
Sean Sube 2023-04-14 20:29:44 -05:00
parent ad35c41c9d
commit 27954f3e65
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 172 additions and 74 deletions

View File

@ -42,6 +42,9 @@ def run_highres(
inversions: List[Tuple[str, float]],
loras: List[Tuple[str, float]],
) -> None:
if highres.scale <= 1:
return image
highres_progress = ChainProgress.from_progress(progress)
if upscale.faces and (
@ -213,7 +216,6 @@ def run_txt2img_pipeline(
del pipe
for image, output in image_outputs:
if highres.scale > 1:
image = run_highres(
job,
server,
@ -321,7 +323,6 @@ def run_img2img_pipeline(
images.append(source)
for image, output in zip(images, outputs):
if highres.scale > 1:
image = run_highres(
job,
server,
@ -373,6 +374,10 @@ def run_inpaint_pipeline(
progress = job.get_progress_callback()
stage = StageParams(tile_order=tile_order)
(prompt, loras) = get_loras_from_prompt(params.prompt)
(prompt, inversions) = get_inversions_from_prompt(prompt)
params.prompt = prompt
# calling the upscale_outpaint stage directly needs accumulating progress
progress = ChainProgress.from_progress(progress)
@ -391,6 +396,19 @@ def run_inpaint_pipeline(
callback=progress,
)
image = run_highres(
job,
server,
params,
size,
upscale,
highres,
image,
progress,
inversions,
loras,
)
image = run_upscale_correction(
job,
server,
@ -424,10 +442,28 @@ def run_upscale_pipeline(
progress = job.get_progress_callback()
stage = StageParams()
(prompt, loras) = get_loras_from_prompt(params.prompt)
(prompt, inversions) = get_inversions_from_prompt(prompt)
params.prompt = prompt
image = run_upscale_correction(
job, server, stage, params, source, upscale=upscale, callback=progress
)
# TODO: should this come first?
image = run_highres(
job,
server,
params,
size,
upscale,
highres,
image,
progress,
inversions,
loras,
)
dest = save_image(server, outputs[0], image)
save_params(server, outputs[0], params, size, upscale=upscale)
@ -463,6 +499,20 @@ def run_blend_pipeline(
)
image = image.convert("RGB")
# TODO: blend tab doesn't have a prompt
image = run_highres(
job,
server,
params,
size,
upscale,
highres,
image,
progress,
[],
[],
)
image = run_upscale_correction(
job, server, stage, params, image, upscale=upscale, callback=progress
)

View File

@ -164,6 +164,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server, "img2img")
upscale = upscale_from_request()
highres = highres_from_request()
source_filter = get_from_list(
request.args, "sourceFilter", list(get_source_filters().keys())
)
@ -195,13 +196,14 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
params,
output,
upscale,
highres,
source,
strength,
needs_device=device,
source_filter=source_filter,
)
return jsonify(json_params(output, params, size, upscale=upscale))
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
@ -243,6 +245,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server, "inpaint")
expand = border_from_request()
upscale = upscale_from_request()
highres = highres_from_request()
fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
@ -280,6 +283,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
size,
output,
upscale,
highres,
source,
mask,
expand,
@ -290,7 +294,11 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
return jsonify(
json_params(
output, params, size, upscale=upscale, border=expand, highres=highres
)
)
def upscale(server: ServerContext, pool: DevicePoolExecutor):
@ -302,6 +310,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
highres = highres_from_request()
output = make_output_name(server, "upscale", params, size)
job_name = output[0]
@ -316,11 +325,12 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
size,
output,
upscale,
highres,
source,
needs_device=device,
)
return jsonify(json_params(output, params, size, upscale=upscale))
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
def chain(server: ServerContext, pool: DevicePoolExecutor):

View File

@ -229,21 +229,25 @@ export type RetryParams = {
model: ModelParams;
params: Img2ImgParams;
upscale?: UpscaleParams;
highres?: HighresParams;
} | {
type: 'inpaint';
model: ModelParams;
params: InpaintParams;
upscale?: UpscaleParams;
highres?: HighresParams;
} | {
type: 'outpaint';
model: ModelParams;
params: OutpaintParams;
upscale?: UpscaleParams;
highres?: HighresParams;
} | {
type: 'upscale';
model: ModelParams;
params: UpscaleReqParams;
upscale?: UpscaleParams;
highres?: HighresParams;
} | {
type: 'blend';
model: ModelParams;
@ -307,22 +311,22 @@ export interface ApiClient {
/**
* Start an im2img pipeline.
*/
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
/**
* Start an inpaint pipeline.
*/
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
/**
* Start an outpaint pipeline.
*/
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
/**
* Start an upscale pipeline.
*/
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
/**
* Start a blending pipeline.
@ -431,6 +435,16 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
}
}
export function appendHighresToURL(url: URL, highres: HighresParams) {
if (highres.enabled) {
url.searchParams.append('highresIterations', highres.highresIterations.toFixed(FIXED_INTEGER));
url.searchParams.append('highresMethod', highres.highresMethod);
url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER));
url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER));
url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT));
}
}
/**
* Make an API client using the given API root and fetch client.
*/
@ -484,7 +498,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
translation: Record<string, string>;
}>;
},
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);
@ -498,6 +512,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}
if (doesExist(highres)) {
appendHighresToURL(url, highres);
}
const body = new FormData();
body.append('source', params.source, 'source');
@ -531,12 +549,8 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}
if (doesExist(highres) && highres.enabled) {
url.searchParams.append('highresIterations', highres.highresIterations.toFixed(FIXED_INTEGER));
url.searchParams.append('highresMethod', highres.highresMethod);
url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER));
url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER));
url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT));
if (doesExist(highres)) {
appendHighresToURL(url, highres);
}
const image = await parseRequest(url, {
@ -553,7 +567,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
},
};
},
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
@ -566,6 +580,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}
if (doesExist(highres)) {
appendHighresToURL(url, highres);
}
const body = new FormData();
body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source');
@ -584,7 +602,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
},
};
},
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
@ -598,6 +616,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}
if (doesExist(highres)) {
appendHighresToURL(url, highres);
}
if (doesExist(params.left)) {
url.searchParams.append('left', params.left.toFixed(FIXED_INTEGER));
}
@ -632,7 +654,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
},
};
},
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
const url = makeApiUrl(root, 'upscale');
appendModelToURL(url, model);
@ -640,6 +662,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}
if (doesExist(highres)) {
appendHighresToURL(url, highres);
}
url.searchParams.append('prompt', params.prompt);
if (doesExist(params.negativePrompt)) {
@ -714,15 +740,15 @@ export function makeClient(root: string, f = fetch): ApiClient {
case 'blend':
return this.blend(retry.model, retry.params, retry.upscale);
case 'img2img':
return this.img2img(retry.model, retry.params, retry.upscale);
return this.img2img(retry.model, retry.params, retry.upscale, retry.highres);
case 'inpaint':
return this.inpaint(retry.model, retry.params, retry.upscale);
return this.inpaint(retry.model, retry.params, retry.upscale, retry.highres);
case 'outpaint':
return this.outpaint(retry.model, retry.params, retry.upscale);
return this.outpaint(retry.model, retry.params, retry.upscale, retry.highres);
case 'txt2img':
return this.txt2img(retry.model, retry.params, retry.upscale, retry.highres);
case 'upscale':
return this.upscale(retry.model, retry.params, retry.upscale);
return this.upscale(retry.model, retry.params, retry.upscale, retry.highres);
default:
throw new InvalidArgumentError('unknown request type');
}

View File

@ -73,6 +73,21 @@ export function ModelControl() {
});
}}
/>
<QueryList
id='pipeline'
labelKey='pipeline'
name={t('parameter.pipeline')}
query={{
result: pipelines,
}}
showEmpty
value={params.pipeline}
onChange={(pipeline) => {
setModel({
pipeline,
});
}}
/>
<QueryList
id='diffusion'
labelKey='model'
@ -118,6 +133,8 @@ export function ModelControl() {
});
}}
/>
</Stack>
<Stack direction='row' spacing={2}>
<QueryList
id='control'
labelKey='model.control'
@ -133,23 +150,6 @@ export function ModelControl() {
});
}}
/>
</Stack>
<Stack direction='row' spacing={2}>
<QueryList
id='pipeline'
labelKey='pipeline'
name={t('parameter.pipeline')}
query={{
result: pipelines,
}}
showEmpty
value={params.pipeline}
onChange={(pipeline) => {
setModel({
pipeline,
});
}}
/>
<QueryMenu
id='inversion'
labelKey='model.inversion'

View File

@ -13,16 +13,17 @@ import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { NumericField } from '../input/NumericField.js';
import { QueryList } from '../input/QueryList.js';
import { HighresControl } from '../control/HighresControl.js';
export function Img2Img() {
const { params } = mustExist(useContext(ConfigContext));
async function uploadSource() {
const { model, img2img, upscale } = state.getState();
const { model, img2img, upscale, highres } = state.getState();
const { image, retry } = await client.img2img(model, {
...img2img,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, upscale);
}, upscale, highres);
pushHistory(image, retry);
}
@ -50,11 +51,16 @@ export function Img2Img() {
return <Box>
<Stack spacing={2}>
<ImageInput filter={IMAGE_FILTER} image={source} label={t('input.image.source')} onChange={(file) => {
<ImageInput
filter={IMAGE_FILTER}
image={source}
label={t('input.image.source')}
onChange={(file) => {
setImg2Img({
source: file,
});
}} />
}}
/>
<ImageControl selector={(s) => s.img2img} onChange={setImg2Img} />
<Stack direction='row' spacing={2}>
<QueryList
@ -86,6 +92,7 @@ export function Img2Img() {
}}
/>
</Stack>
<HighresControl />
<UpscaleControl />
<Button
disabled={doesExist(source) === false}

View File

@ -15,6 +15,7 @@ import { ImageInput } from '../input/ImageInput.js';
import { MaskCanvas } from '../input/MaskCanvas.js';
import { NumericField } from '../input/NumericField.js';
import { QueryList } from '../input/QueryList.js';
import { HighresControl } from '../control/HighresControl.js';
export function Inpaint() {
const { params } = mustExist(useContext(ConfigContext));
@ -29,7 +30,7 @@ export function Inpaint() {
async function uploadSource(): Promise<void> {
// these are not watched by the component, only sent by the mutation
const { model, inpaint, outpaint, upscale } = state.getState();
const { model, inpaint, outpaint, upscale, highres } = state.getState();
if (outpaint.enabled) {
const { image, retry } = await client.outpaint(model, {
@ -37,7 +38,7 @@ export function Inpaint() {
...outpaint,
mask: mustExist(mask),
source: mustExist(source),
}, upscale);
}, upscale, highres);
pushHistory(image, retry);
} else {
@ -45,7 +46,7 @@ export function Inpaint() {
...inpaint,
mask: mustExist(mask),
source: mustExist(source),
}, upscale);
}, upscale, highres);
pushHistory(image, retry);
}
@ -207,6 +208,7 @@ export function Inpaint() {
</Stack>
</Stack>
<OutpaintControl />
<HighresControl />
<UpscaleControl />
<Button
disabled={preventInpaint()}

View File

@ -11,6 +11,7 @@ import { ClientContext, StateContext } from '../../state.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { PromptInput } from '../input/PromptInput.js';
import { HighresControl } from '../control/HighresControl.js';
export function Upscale() {
async function uploadSource() {
@ -56,6 +57,7 @@ export function Upscale() {
setSource(value);
}}
/>
<HighresControl />
<UpscaleControl />
<Button
disabled={doesExist(params.source) === false}

View File

@ -33,6 +33,7 @@
"ftfy",
"gfpgan",
"Heun",
"Highres",
"huggingface",
"Inpaint",
"inpainting",