1
0
Fork 0

feat(gui): add menus for upscaling and correction models

This commit is contained in:
Sean Sube 2023-01-16 20:11:10 -06:00
parent ee6308a091
commit 0080d86d91
13 changed files with 247 additions and 163 deletions

View File

@ -2,22 +2,23 @@ import { doesExist } from '@apextoaster/js-utils';
import { ConfigParams } from './config.js';
export interface BaseImgParams {
export interface ModelParams {
/**
* Which ONNX model to use.
*/
model?: string;
model: string;
/**
* Hardware accelerator or CPU mode.
*/
platform?: string;
platform: string;
/**
* Scheduling algorithm.
*/
scheduler?: string;
upscaling: string;
correction: string;
}
export interface BaseImgParams {
scheduler: string;
prompt: string;
negativePrompt?: string;
@ -90,18 +91,24 @@ export interface ApiReady {
ready: boolean;
}
export interface ApiModels {
diffusion: Array<string>;
correction: Array<string>;
upscaling: Array<string>;
}
export interface ApiClient {
masks(): Promise<Array<string>>;
models(): Promise<Array<string>>;
models(): Promise<ApiModels>;
noises(): Promise<Array<string>>;
params(): Promise<ConfigParams>;
platforms(): Promise<Array<string>>;
schedulers(): Promise<Array<string>>;
img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
inpaint(params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
outpaint(params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
ready(params: ApiResponse): Promise<ApiReady>;
}
@ -111,9 +118,7 @@ export const STATUS_SUCCESS = 200;
export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams> {
return {
cfg: defaults.cfg.default,
model: defaults.model.default,
negativePrompt: defaults.negativePrompt.default,
platform: defaults.platform.default,
prompt: defaults.prompt.default,
scheduler: defaults.scheduler.default,
steps: defaults.steps.default,
@ -141,14 +146,6 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
}
if (doesExist(params.platform)) {
url.searchParams.append('platform', params.platform);
}
if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler);
}
@ -167,6 +164,11 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
return url;
}
export function appendModelToURL(url: URL, params: ModelParams) {
url.searchParams.append('model', params.model);
url.searchParams.append('platform', params.platform);
}
export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
if (upscale.enabled) {
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
@ -191,10 +193,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async models(): Promise<Array<string>> {
async models(): Promise<ApiModels> {
const path = makeApiUrl(root, 'settings', 'models');
const res = await f(path);
return await res.json() as Array<string>;
return await res.json() as ApiModels;
},
async noises(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'noises');
@ -216,12 +218,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
if (doesExist(pending)) {
return pending;
}
const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
if (doesExist(upscale)) {
@ -239,12 +243,13 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
if (doesExist(pending)) {
return pending;
}
const url = makeImageURL(root, 'txt2img', params);
appendModelToURL(url, model);
if (doesExist(params.width)) {
url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER));
@ -265,14 +270,17 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async inpaint(params: InpaintParams, upscale?: UpscaleParams) {
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams) {
if (doesExist(pending)) {
return pending;
}
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
url.searchParams.append('filter', params.filter);
url.searchParams.append('noise', params.noise);
if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}
@ -289,12 +297,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async outpaint(params: OutpaintParams, upscale?: UpscaleParams) {
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams) {
if (doesExist(pending)) {
return pending;
}
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
url.searchParams.append('filter', params.filter);
url.searchParams.append('noise', params.noise);

View File

@ -41,7 +41,9 @@ export function ImageControl(props: ImageControlProps) {
id='schedulers'
labels={SCHEDULER_LABELS}
name='Scheduler'
result={schedulers}
query={{
result: schedulers,
}}
value={mustDefault(params.scheduler, '')}
onChange={(value) => {
if (doesExist(props.onChange)) {

View File

@ -4,8 +4,8 @@ import * as React from 'react';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';
import { ConfigParams, IMAGE_FILTER } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { IMAGE_FILTER } from '../config.js';
import { ClientContext, ConfigContext, StateContext } from '../state.js';
import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js';
import { NumericField } from './NumericField.js';
@ -13,23 +13,14 @@ import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React;
export interface Img2ImgProps {
config: ConfigParams;
model: string;
platform: string;
}
export function Img2Img(props: Img2ImgProps) {
const { config, model, platform } = props;
export function Img2Img() {
const config = mustExist(useContext(ConfigContext));
async function uploadSource() {
const { img2img, upscale } = state.getState();
const { model, img2img, upscale } = state.getState();
const output = await client.img2img({
const output = await client.img2img(model, {
...img2img,
model,
platform,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, upscale);

View File

@ -4,8 +4,8 @@ import * as React from 'react';
import { useMutation, useQuery, useQueryClient } from 'react-query';
import { useStore } from 'zustand';
import { ConfigParams, IMAGE_FILTER, STALE_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { IMAGE_FILTER, STALE_TIME } from '../config.js';
import { ClientContext, ConfigContext, StateContext } from '../state.js';
import { MASK_LABELS, NOISE_LABELS } from '../strings.js';
import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js';
@ -16,16 +16,10 @@ import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React;
export interface InpaintProps {
config: ConfigParams;
model: string;
platform: string;
}
export function Inpaint(props: InpaintProps) {
const { config, model, platform } = props;
export function Inpaint() {
const config = mustExist(useContext(ConfigContext));
const client = mustExist(useContext(ClientContext));
const masks = useQuery('masks', async () => client.masks(), {
staleTime: STALE_TIME,
});
@ -35,24 +29,20 @@ export function Inpaint(props: InpaintProps) {
async function uploadSource(): Promise<void> {
// these are not watched by the component, only sent by the mutation
const { inpaint, outpaint, upscale } = state.getState();
const { model, inpaint, outpaint, upscale } = state.getState();
if (outpaint.enabled) {
const output = await client.outpaint({
const output = await client.outpaint(model, {
...inpaint,
...outpaint,
model,
platform,
mask: mustExist(mask),
source: mustExist(source),
}, upscale);
setLoading(output);
} else {
const output = await client.inpaint({
const output = await client.inpaint(model, {
...inpaint,
model,
platform,
mask: mustExist(mask),
source: mustExist(source),
}, upscale);
@ -122,7 +112,9 @@ export function Inpaint(props: InpaintProps) {
id='masks'
labels={MASK_LABELS}
name='Mask Filter'
result={masks}
query={{
result: masks,
}}
value={filter}
onChange={(newFilter) => {
setInpaint({
@ -134,7 +126,9 @@ export function Inpaint(props: InpaintProps) {
id='noises'
labels={NOISE_LABELS}
name='Noise Source'
result={noises}
query={{
result: noises,
}}
value={noise}
onChange={(newNoise) => {
setInpaint({

View File

@ -0,0 +1,89 @@
import { mustExist } from '@apextoaster/js-utils';
import { Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useQuery } from 'react-query';
import { useStore } from 'zustand';
import { STALE_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
import { QueryList } from './QueryList.js';
export function ModelControl() {
const client = mustExist(useContext(ClientContext));
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.model);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setModel = useStore(state, (s) => s.setModel);
const models = useQuery('models', async () => client.models(), {
staleTime: STALE_TIME,
});
const platforms = useQuery('platforms', async () => client.platforms(), {
staleTime: STALE_TIME,
});
return <Stack direction='row' spacing={2}>
<QueryList
id='platforms'
labels={PLATFORM_LABELS}
name='Platform'
query={{
result: platforms,
}}
value={params.platform}
onChange={(platform) => {
setModel({
platform,
});
}}
/>
<QueryList
id='diffusion'
labels={MODEL_LABELS}
name='Diffusion Model'
query={{
result: models,
selector: (result) => result.diffusion,
}}
value={params.model}
onChange={(model) => {
setModel({
model,
});
}}
/>
<QueryList
id='upscaling'
labels={MODEL_LABELS}
name='Upscaling Model'
query={{
result: models,
selector: (result) => result.upscaling,
}}
value={params.model}
onChange={(model) => {
setModel({
model,
});
}}
/>
<QueryList
id='correction'
labels={MODEL_LABELS}
name='Correction Model'
query={{
result: models,
selector: (result) => result.correction,
}}
value={params.model}
onChange={(model) => {
setModel({
model,
});
}}
/>
</Stack>;
}

View File

@ -1,41 +1,18 @@
import { mustExist } from '@apextoaster/js-utils';
import { TabContext, TabList, TabPanel } from '@mui/lab';
import { Box, Container, Divider, Link, Stack, Tab, Typography } from '@mui/material';
import { Box, Container, Divider, Link, Tab, Typography } from '@mui/material';
import * as React from 'react';
import { useQuery } from 'react-query';
import { ApiClient } from '../client.js';
import { ConfigParams, STALE_TIME } from '../config.js';
import { ClientContext } from '../state.js';
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
import { ImageHistory } from './ImageHistory.js';
import { Img2Img } from './Img2Img.js';
import { Inpaint } from './Inpaint.js';
import { QueryList } from './QueryList.js';
import { ModelControl } from './ModelControl.js';
import { Settings } from './Settings.js';
import { Txt2Img } from './Txt2Img.js';
const { useContext, useState } = React;
const { useState } = React;
export interface OnnxWebProps {
client: ApiClient;
config: ConfigParams;
}
export function OnnxWeb(props: OnnxWebProps) {
const { config } = props;
const client = mustExist(useContext(ClientContext));
export function OnnxWeb() {
const [tab, setTab] = useState('txt2img');
const [model, setModel] = useState(config.model.default);
const [platform, setPlatform] = useState(config.platform.default);
const models = useQuery('models', async () => client.models(), {
staleTime: STALE_TIME,
});
const platforms = useQuery('platforms', async () => client.platforms(), {
staleTime: STALE_TIME,
});
return (
<Container>
@ -45,28 +22,7 @@ export function OnnxWeb(props: OnnxWebProps) {
</Typography>
</Box>
<Box sx={{ mx: 4, my: 4 }}>
<Stack direction='row' spacing={2}>
<QueryList
id='models'
labels={MODEL_LABELS}
name='Model'
result={models}
value={model}
onChange={(value) => {
setModel(value);
}}
/>
<QueryList
id='platforms'
labels={PLATFORM_LABELS}
name='Platform'
result={platforms}
value={platform}
onChange={(value) => {
setPlatform(value);
}}
/>
</Stack>
<ModelControl />
</Box>
<TabContext value={tab}>
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
@ -80,16 +36,16 @@ export function OnnxWeb(props: OnnxWebProps) {
</TabList>
</Box>
<TabPanel value='txt2img'>
<Txt2Img config={config} model={model} platform={platform} />
<Txt2Img />
</TabPanel>
<TabPanel value='img2img'>
<Img2Img config={config} model={model} platform={platform} />
<Img2Img />
</TabPanel>
<TabPanel value='inpaint'>
<Inpaint config={config} model={model} platform={platform} />
<Inpaint />
</TabPanel>
<TabPanel value='settings'>
<Settings config={config} />
<Settings />
</TabPanel>
</TabContext>
<Divider variant='middle' />

View File

@ -3,18 +3,42 @@ import { FormControl, InputLabel, MenuItem, Select } from '@mui/material';
import * as React from 'react';
import { UseQueryResult } from 'react-query';
export interface QueryListProps {
export interface QueryListComplete {
result: UseQueryResult<Array<string>>;
}
export interface QueryListFilter<T> {
result: UseQueryResult<T>;
selector: (result: T) => Array<string>;
}
export interface QueryListProps<T> {
id: string;
labels: Record<string, string>;
name: string;
result: UseQueryResult<Array<string>>;
value: string;
query: QueryListComplete | QueryListFilter<T>;
onChange?: (value: string) => void;
}
export function QueryList(props: QueryListProps) {
const { labels, result, value } = props;
export function hasFilter<T>(query: QueryListComplete | QueryListFilter<T>): query is QueryListFilter<T> {
return Reflect.has(query, 'selector');
}
export function filterQuery<T>(query: QueryListComplete | QueryListFilter<T>): Array<string> {
if (hasFilter(query)) {
const data = mustExist(query.result.data);
return (query as QueryListFilter<unknown>).selector(data);
} else {
return mustExist(query.result.data);
}
}
export function QueryList<T>(props: QueryListProps<T>) {
const { labels, query, value } = props;
const { result } = query;
if (result.status === 'error') {
if (result.error instanceof Error) {
@ -34,7 +58,8 @@ export function QueryList(props: QueryListProps) {
// else: success
const labelID = `query-list-${props.id}-labels`;
const data = mustExist(result.data);
const data = filterQuery(query);
return <FormControl>
<InputLabel id={labelID}>{props.name}</InputLabel>
<Select

View File

@ -1,19 +1,13 @@
import { mustExist } from '@apextoaster/js-utils';
import { Button, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useStore } from 'zustand';
import { ConfigParams } from '../config.js';
import { StateContext } from '../state.js';
import { NumericField } from './NumericField.js';
const { useContext } = React;
export interface SettingsProps {
config: ConfigParams;
}
export function Settings(_props: SettingsProps) {
export function Settings() {
const state = useStore(mustExist(useContext(StateContext)));
return <Stack spacing={2}>
@ -25,16 +19,6 @@ export function Settings(_props: SettingsProps) {
value={state.limit}
onChange={(value) => state.setLimit(value)}
/>
<TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => {
state.setDefaults({
model: event.target.value,
});
}} />
<TextField variant='outlined' label='Default Platform' value={state.defaults.platform} onChange={(event) => {
state.setDefaults({
platform: event.target.value,
});
}} />
<TextField variant='outlined' label='Default Prompt' value={state.defaults.prompt} onChange={(event) => {
state.setDefaults({
prompt: event.target.value,

View File

@ -4,31 +4,19 @@ import * as React from 'react';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';
import { ConfigParams } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { ClientContext, ConfigContext, StateContext } from '../state.js';
import { ImageControl } from './ImageControl.js';
import { NumericField } from './NumericField.js';
import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React;
export interface Txt2ImgProps {
config: ConfigParams;
model: string;
platform: string;
}
export function Txt2Img(props: Txt2ImgProps) {
const { config, model, platform } = props;
export function Txt2Img() {
const config = mustExist(useContext(ConfigContext));
async function generateImage() {
const { txt2img, upscale } = state.getState();
const output = await client.txt2img({
...txt2img,
model,
platform,
}, upscale);
const { model, txt2img, upscale } = state.getState();
const output = await client.txt2img(model, txt2img, upscale);
setLoading(output);
}

View File

@ -1,6 +1,6 @@
import { Maybe } from '@apextoaster/js-utils';
import { Img2ImgParams, InpaintParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client.js';
import { Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client.js';
export interface ConfigNumber {
default: number;
@ -30,7 +30,16 @@ export type ConfigState<T extends object, TValid = number | string> = {
[K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never;
};
export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams & InpaintParams & OutpaintParams & UpscaleParams>>;
/* eslint-disable */
export type ConfigParams = ConfigRanges<Required<
Img2ImgParams &
Txt2ImgParams &
InpaintParams &
ModelParams &
OutpaintParams &
UpscaleParams
>>;
/* eslint-enable */
export interface Config {
api: {

View File

@ -11,7 +11,7 @@ import { makeClient } from './client.js';
import { OnnxError } from './components/OnnxError.js';
import { OnnxWeb } from './components/OnnxWeb.js';
import { Config, loadConfig } from './config.js';
import { ClientContext, createStateSlices, OnnxState, StateContext } from './state.js';
import { ClientContext, ConfigContext, createStateSlices, OnnxState, StateContext } from './state.js';
export function getApiRoot(config: Config): string {
const query = new URLSearchParams(window.location.search);
@ -48,6 +48,7 @@ export async function main() {
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createModelSlice,
createOutpaintSlice,
createTxt2ImgSlice,
createUpscaleSlice,
@ -58,6 +59,7 @@ export async function main() {
...createHistorySlice(...slice),
...createImg2ImgSlice(...slice),
...createInpaintSlice(...slice),
...createModelSlice(...slice),
...createTxt2ImgSlice(...slice),
...createOutpaintSlice(...slice),
...createUpscaleSlice(...slice),
@ -87,9 +89,11 @@ export async function main() {
// go
app.render(<QueryClientProvider client={query}>
<ClientContext.Provider value={client}>
<StateContext.Provider value={state}>
<OnnxWeb client={client} config={params} />
</StateContext.Provider>
<ConfigContext.Provider value={params}>
<StateContext.Provider value={state}>
<OnnxWeb />
</StateContext.Provider>
</ConfigContext.Provider>
</ClientContext.Provider>
</QueryClientProvider>);
} catch (err) {

View File

@ -10,6 +10,7 @@ import {
BrushParams,
Img2ImgParams,
InpaintParams,
ModelParams,
OutpaintPixels,
paramsFromConfig,
Txt2ImgParams,
@ -75,12 +76,19 @@ interface UpscaleSlice {
setUpscale(upscale: Partial<UpscaleParams>): void;
}
interface ModelSlice {
model: ModelParams;
setModel(model: Partial<ModelParams>): void;
}
export type OnnxState
= BrushSlice
& DefaultSlice
& HistorySlice
& Img2ImgSlice
& InpaintSlice
& ModelSlice
& OutpaintSlice
& Txt2ImgSlice
& UpscaleSlice;
@ -267,12 +275,30 @@ export function createStateSlices(base: ConfigParams) {
},
});
const createModelSlice: StateCreator<OnnxState, [], [], ModelSlice> = (set) => ({
model: {
model: '',
platform: '',
upscaling: '',
correction: '',
},
setModel(params) {
set((prev) => ({
model: {
...prev.model,
...params,
}
}));
},
});
return {
createBrushSlice,
createDefaultSlice,
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createModelSlice,
createOutpaintSlice,
createTxt2ImgSlice,
createUpscaleSlice,
@ -280,4 +306,5 @@ export function createStateSlices(base: ConfigParams) {
}
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
export const ConfigContext = createContext<Maybe<ConfigParams>>(undefined);
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);

View File

@ -19,6 +19,7 @@
"ddpm",
"denoise",
"directml",
"ESRGAN",
"ftfy",
"gfpgan",
"Heun",
@ -33,20 +34,24 @@
"numpy",
"Onnx",
"onnxruntime",
"opset",
"outpaint",
"outscale",
"pndm",
"pretrained",
"protobuf",
"resrgan",
"RRDB",
"runwayml",
"scandir",
"scipy",
"Singlestep",
"spacy",
"spinalcase",
"stabilityai",
"stringcase",
"upsampler",
"upscaling",
"venv",
"virtualenv",
"zustand"