1
0
Fork 0

feat: add upscale controls to client, params to server

This commit is contained in:
Sean Sube 2023-01-16 14:05:04 -06:00
parent 1f0c19af04
commit d1e4fa9cf1
9 changed files with 188 additions and 25 deletions

View File

@ -5,6 +5,12 @@
"max": 30,
"step": 0.1
},
"denoise": {
"default": 0.5,
"min": 0,
"max": 0,
"step": 0.1
},
"height": {
"default": 512,
"min": 64,
@ -27,6 +33,12 @@
"default": "an astronaut eating a hamburger",
"keys": []
},
"scale": {
"default": 1,
"min": 1,
"max": 4,
"step": 1
},
"scheduler": {
"default": "euler-a",
"keys": []

View File

@ -65,6 +65,14 @@ export interface BrushParams {
strength: number;
}
export interface UpscaleParams {
enabled: boolean;
denoise: number;
faces: boolean;
scale: number;
}
export interface ApiResponse {
output: {
key: string;
@ -112,6 +120,9 @@ export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams
};
}
export const FIXED_INTEGER = 0;
export const FIXED_FLOAT = 2;
export function equalResponse(a: ApiResponse, b: ApiResponse): boolean {
return a.output === b.output;
}
@ -126,8 +137,8 @@ export function makeApiUrl(root: string, ...path: Array<string>) {
export function makeImageURL(root: string, type: string, params: BaseImgParams): URL {
const url = makeApiUrl(root, type);
url.searchParams.append('cfg', params.cfg.toFixed(1));
url.searchParams.append('steps', params.steps.toFixed(0));
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
@ -142,7 +153,7 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
}
if (doesExist(params.seed)) {
url.searchParams.append('seed', params.seed.toFixed(0));
url.searchParams.append('seed', params.seed.toFixed(FIXED_INTEGER));
}
// put prompt last, in case a load balancer decides to truncate the URL
@ -155,6 +166,12 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
return url;
}
export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
url.searchParams.append('faces', String(upscale.faces));
url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER));
}
export function makeClient(root: string, f = fetch): ApiClient {
let pending: Promise<ApiResponse> | undefined;
@ -195,13 +212,17 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async img2img(params: Img2ImgParams): Promise<ApiResponse> {
async img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
if (doesExist(pending)) {
return pending;
}
const url = makeImageURL(root, 'img2img', params);
url.searchParams.append('strength', params.strength.toFixed(2));
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}
const body = new FormData();
body.append('source', params.source, 'source');
@ -214,7 +235,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async txt2img(params: Txt2ImgParams): Promise<ApiResponse> {
async txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
if (doesExist(pending)) {
return pending;
}
@ -222,11 +243,15 @@ export function makeClient(root: string, f = fetch): ApiClient {
const url = makeImageURL(root, 'txt2img', params);
if (doesExist(params.width)) {
url.searchParams.append('width', params.width.toFixed(0));
url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER));
}
if (doesExist(params.height)) {
url.searchParams.append('height', params.height.toFixed(0));
url.searchParams.append('height', params.height.toFixed(FIXED_INTEGER));
}
if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}
pending = throttleRequest(url, {
@ -236,7 +261,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async inpaint(params: InpaintParams) {
async inpaint(params: InpaintParams, upscale?: UpscaleParams) {
if (doesExist(pending)) {
return pending;
}
@ -244,6 +269,9 @@ export function makeClient(root: string, f = fetch): ApiClient {
const url = makeImageURL(root, 'inpaint', params);
url.searchParams.append('filter', params.filter);
url.searchParams.append('noise', params.noise);
if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}
const body = new FormData();
body.append('mask', params.mask, 'mask');
@ -257,7 +285,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async outpaint(params: OutpaintParams) {
async outpaint(params: OutpaintParams, upscale?: UpscaleParams) {
if (doesExist(pending)) {
return pending;
}
@ -266,20 +294,24 @@ export function makeClient(root: string, f = fetch): ApiClient {
url.searchParams.append('filter', params.filter);
url.searchParams.append('noise', params.noise);
if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}
if (doesExist(params.left)) {
url.searchParams.append('left', params.left.toFixed(0));
url.searchParams.append('left', params.left.toFixed(FIXED_INTEGER));
}
if (doesExist(params.right)) {
url.searchParams.append('right', params.right.toFixed(0));
url.searchParams.append('right', params.right.toFixed(FIXED_INTEGER));
}
if (doesExist(params.top)) {
url.searchParams.append('top', params.top.toFixed(0));
url.searchParams.append('top', params.top.toFixed(FIXED_INTEGER));
}
if (doesExist(params.bottom)) {
url.searchParams.append('bottom', params.bottom.toFixed(0));
url.searchParams.append('bottom', params.bottom.toFixed(FIXED_INTEGER));
}
const body = new FormData();

View File

@ -9,6 +9,7 @@ import { ClientContext, StateContext } from '../state.js';
import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js';
import { NumericField } from './NumericField.js';
import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React;
@ -69,6 +70,7 @@ export function Img2Img(props: Img2ImgProps) {
});
}}
/>
<UpscaleControl config={config} />
<Button onClick={() => upload.mutate()}>Generate</Button>
</Stack>
</Box>;

View File

@ -12,6 +12,7 @@ import { ImageInput } from './ImageInput.js';
import { MaskCanvas } from './MaskCanvas.js';
import { OutpaintControl } from './OutpaintControl.js';
import { QueryList } from './QueryList.js';
import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React;
@ -139,6 +140,7 @@ export function Inpaint(props: InpaintProps) {
/>
</Stack>
<OutpaintControl config={config} />
<UpscaleControl config={config} />
<Button onClick={() => upload.mutate()}>Generate</Button>
</Stack>
</Box>;

View File

@ -8,6 +8,7 @@ import { ConfigParams } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { ImageControl } from './ImageControl.js';
import { NumericField } from './NumericField.js';
import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React;
@ -75,6 +76,7 @@ export function Txt2Img(props: Txt2ImgProps) {
}}
/>
</Stack>
<UpscaleControl config={config} />
<Button onClick={() => generate.mutate()}>Generate</Button>
</Stack>
</Box>;

View File

@ -0,0 +1,78 @@
import { mustExist } from '@apextoaster/js-utils';
import { Check, FaceRetouchingNatural, ZoomIn } from '@mui/icons-material';
import { Stack, ToggleButton } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useStore } from 'zustand';
import { ConfigParams } from '../config.js';
import { StateContext } from '../state.js';
import { NumericField } from './NumericField.js';
export interface UpscaleControlProps {
config: ConfigParams;
}
export function UpscaleControl(props: UpscaleControlProps) {
const { config } = props;
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.upscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setUpscale);
return <Stack direction='row' spacing={4}>
<ToggleButton
color='primary'
selected={params.enabled}
value='check'
onChange={(event) => {
setUpscale({
enabled: params.enabled === false,
});
}}
>
<ZoomIn />
Upscale
</ToggleButton>
<NumericField
label='Scale'
disabled={params.enabled === false}
min={config.scale.min}
max={config.scale.max}
step={config.scale.step}
value={params.scale}
onChange={(scale) => {
setUpscale({
scale,
});
}}
/>
<NumericField
label='Denoise'
disabled={params.enabled === false}
min={config.denoise.min}
max={config.denoise.max}
step={config.denoise.step}
value={params.denoise}
onChange={(denoise) => {
setUpscale({
denoise,
});
}}
/>
<ToggleButton
color='primary'
selected={params.enabled}
value='check'
onChange={(event) => {
setUpscale({
faces: params.faces === false,
});
}}
>
<FaceRetouchingNatural />
Face Correction
</ToggleButton>
</Stack>;
}

View File

@ -1,6 +1,6 @@
import { Maybe } from '@apextoaster/js-utils';
import { Img2ImgParams, STATUS_SUCCESS, Txt2ImgParams } from './api/client.js';
import { Img2ImgParams, InpaintParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './api/client.js';
export interface ConfigNumber {
default: number;
@ -30,7 +30,7 @@ export type ConfigState<T extends object, TValid = number | string> = {
[K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never;
};
export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams>>;
export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams & InpaintParams & OutpaintParams & UpscaleParams>>;
export interface Config {
api: {

View File

@ -43,22 +43,24 @@ export async function main() {
// prep zustand with a slice for each tab, using local storage
const {
createBrushSlice,
createDefaultSlice,
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createTxt2ImgSlice,
createBrushSlice,
createOutpaintSlice,
createTxt2ImgSlice,
createUpscaleSlice,
} = createStateSlices(params);
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((...slice) => ({
...createTxt2ImgSlice(...slice),
...createBrushSlice(...slice),
...createDefaultSlice(...slice),
...createHistorySlice(...slice),
...createImg2ImgSlice(...slice),
...createInpaintSlice(...slice),
...createHistorySlice(...slice),
...createDefaultSlice(...slice),
...createBrushSlice(...slice),
...createTxt2ImgSlice(...slice),
...createOutpaintSlice(...slice),
...createUpscaleSlice(...slice),
}), {
name: 'onnx-web',
partialize(s) {

View File

@ -13,6 +13,7 @@ import {
OutpaintPixels,
paramsFromConfig,
Txt2ImgParams,
UpscaleParams,
} from './api/client.js';
import { ConfigFiles, ConfigParams, ConfigState } from './config.js';
@ -68,7 +69,21 @@ interface BrushSlice {
setBrush(brush: Partial<BrushParams>): void;
}
export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice & OutpaintSlice & BrushSlice;
interface UpscaleSlice {
upscale: UpscaleParams;
setUpscale(upscale: Partial<UpscaleParams>): void;
}
export type OnnxState
= BrushSlice
& DefaultSlice
& HistorySlice
& Img2ImgSlice
& InpaintSlice
& OutpaintSlice
& Txt2ImgSlice
& UpscaleSlice;
export function createStateSlices(base: ConfigParams) {
const defaults = paramsFromConfig(base);
@ -220,6 +235,23 @@ export function createStateSlices(base: ConfigParams) {
},
});
const createUpscaleSlice: StateCreator<OnnxState, [], [], UpscaleSlice> = (set) => ({
upscale: {
denoise: 0.5,
enabled: false,
faces: false,
scale: 1,
},
setUpscale(upscale) {
set((prev) => ({
upscale: {
...prev.upscale,
...upscale,
}
}));
},
});
const createDefaultSlice: StateCreator<OnnxState, [], [], DefaultSlice> = (set) => ({
defaults: {
...defaults,
@ -235,13 +267,14 @@ export function createStateSlices(base: ConfigParams) {
});
return {
createBrushSlice,
createDefaultSlice,
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createTxt2ImgSlice,
createOutpaintSlice,
createBrushSlice,
createTxt2ImgSlice,
createUpscaleSlice,
};
}