feat(gui): add menus for upscaling and correction models
This commit is contained in:
parent
ee6308a091
commit
0080d86d91
|
@ -2,22 +2,23 @@ import { doesExist } from '@apextoaster/js-utils';
|
||||||
|
|
||||||
import { ConfigParams } from './config.js';
|
import { ConfigParams } from './config.js';
|
||||||
|
|
||||||
export interface BaseImgParams {
|
export interface ModelParams {
|
||||||
/**
|
/**
|
||||||
* Which ONNX model to use.
|
* Which ONNX model to use.
|
||||||
*/
|
*/
|
||||||
model?: string;
|
model: string;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hardware accelerator or CPU mode.
|
* Hardware accelerator or CPU mode.
|
||||||
*/
|
*/
|
||||||
platform?: string;
|
platform: string;
|
||||||
|
|
||||||
/**
|
upscaling: string;
|
||||||
* Scheduling algorithm.
|
correction: string;
|
||||||
*/
|
}
|
||||||
scheduler?: string;
|
|
||||||
|
|
||||||
|
export interface BaseImgParams {
|
||||||
|
scheduler: string;
|
||||||
prompt: string;
|
prompt: string;
|
||||||
negativePrompt?: string;
|
negativePrompt?: string;
|
||||||
|
|
||||||
|
@ -90,18 +91,24 @@ export interface ApiReady {
|
||||||
ready: boolean;
|
ready: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ApiModels {
|
||||||
|
diffusion: Array<string>;
|
||||||
|
correction: Array<string>;
|
||||||
|
upscaling: Array<string>;
|
||||||
|
}
|
||||||
|
|
||||||
export interface ApiClient {
|
export interface ApiClient {
|
||||||
masks(): Promise<Array<string>>;
|
masks(): Promise<Array<string>>;
|
||||||
models(): Promise<Array<string>>;
|
models(): Promise<ApiModels>;
|
||||||
noises(): Promise<Array<string>>;
|
noises(): Promise<Array<string>>;
|
||||||
params(): Promise<ConfigParams>;
|
params(): Promise<ConfigParams>;
|
||||||
platforms(): Promise<Array<string>>;
|
platforms(): Promise<Array<string>>;
|
||||||
schedulers(): Promise<Array<string>>;
|
schedulers(): Promise<Array<string>>;
|
||||||
|
|
||||||
img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||||
txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||||
inpaint(params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||||
outpaint(params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||||
|
|
||||||
ready(params: ApiResponse): Promise<ApiReady>;
|
ready(params: ApiResponse): Promise<ApiReady>;
|
||||||
}
|
}
|
||||||
|
@ -111,9 +118,7 @@ export const STATUS_SUCCESS = 200;
|
||||||
export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams> {
|
export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams> {
|
||||||
return {
|
return {
|
||||||
cfg: defaults.cfg.default,
|
cfg: defaults.cfg.default,
|
||||||
model: defaults.model.default,
|
|
||||||
negativePrompt: defaults.negativePrompt.default,
|
negativePrompt: defaults.negativePrompt.default,
|
||||||
platform: defaults.platform.default,
|
|
||||||
prompt: defaults.prompt.default,
|
prompt: defaults.prompt.default,
|
||||||
scheduler: defaults.scheduler.default,
|
scheduler: defaults.scheduler.default,
|
||||||
steps: defaults.steps.default,
|
steps: defaults.steps.default,
|
||||||
|
@ -141,14 +146,6 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
|
||||||
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
|
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
|
||||||
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
|
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
|
||||||
|
|
||||||
if (doesExist(params.model)) {
|
|
||||||
url.searchParams.append('model', params.model);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (doesExist(params.platform)) {
|
|
||||||
url.searchParams.append('platform', params.platform);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (doesExist(params.scheduler)) {
|
if (doesExist(params.scheduler)) {
|
||||||
url.searchParams.append('scheduler', params.scheduler);
|
url.searchParams.append('scheduler', params.scheduler);
|
||||||
}
|
}
|
||||||
|
@ -167,6 +164,11 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
|
||||||
return url;
|
return url;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function appendModelToURL(url: URL, params: ModelParams) {
|
||||||
|
url.searchParams.append('model', params.model);
|
||||||
|
url.searchParams.append('platform', params.platform);
|
||||||
|
}
|
||||||
|
|
||||||
export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
|
export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
|
||||||
if (upscale.enabled) {
|
if (upscale.enabled) {
|
||||||
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
|
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
|
||||||
|
@ -191,10 +193,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
const res = await f(path);
|
const res = await f(path);
|
||||||
return await res.json() as Array<string>;
|
return await res.json() as Array<string>;
|
||||||
},
|
},
|
||||||
async models(): Promise<Array<string>> {
|
async models(): Promise<ApiModels> {
|
||||||
const path = makeApiUrl(root, 'settings', 'models');
|
const path = makeApiUrl(root, 'settings', 'models');
|
||||||
const res = await f(path);
|
const res = await f(path);
|
||||||
return await res.json() as Array<string>;
|
return await res.json() as ApiModels;
|
||||||
},
|
},
|
||||||
async noises(): Promise<Array<string>> {
|
async noises(): Promise<Array<string>> {
|
||||||
const path = makeApiUrl(root, 'settings', 'noises');
|
const path = makeApiUrl(root, 'settings', 'noises');
|
||||||
|
@ -216,12 +218,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
const res = await f(path);
|
const res = await f(path);
|
||||||
return await res.json() as Array<string>;
|
return await res.json() as Array<string>;
|
||||||
},
|
},
|
||||||
async img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
||||||
if (doesExist(pending)) {
|
if (doesExist(pending)) {
|
||||||
return pending;
|
return pending;
|
||||||
}
|
}
|
||||||
|
|
||||||
const url = makeImageURL(root, 'img2img', params);
|
const url = makeImageURL(root, 'img2img', params);
|
||||||
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
|
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
|
||||||
|
|
||||||
if (doesExist(upscale)) {
|
if (doesExist(upscale)) {
|
||||||
|
@ -239,12 +243,13 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
// eslint-disable-next-line no-return-await
|
// eslint-disable-next-line no-return-await
|
||||||
return await pending;
|
return await pending;
|
||||||
},
|
},
|
||||||
async txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
||||||
if (doesExist(pending)) {
|
if (doesExist(pending)) {
|
||||||
return pending;
|
return pending;
|
||||||
}
|
}
|
||||||
|
|
||||||
const url = makeImageURL(root, 'txt2img', params);
|
const url = makeImageURL(root, 'txt2img', params);
|
||||||
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
if (doesExist(params.width)) {
|
if (doesExist(params.width)) {
|
||||||
url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER));
|
url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER));
|
||||||
|
@ -265,14 +270,17 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
// eslint-disable-next-line no-return-await
|
// eslint-disable-next-line no-return-await
|
||||||
return await pending;
|
return await pending;
|
||||||
},
|
},
|
||||||
async inpaint(params: InpaintParams, upscale?: UpscaleParams) {
|
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams) {
|
||||||
if (doesExist(pending)) {
|
if (doesExist(pending)) {
|
||||||
return pending;
|
return pending;
|
||||||
}
|
}
|
||||||
|
|
||||||
const url = makeImageURL(root, 'inpaint', params);
|
const url = makeImageURL(root, 'inpaint', params);
|
||||||
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
url.searchParams.append('filter', params.filter);
|
url.searchParams.append('filter', params.filter);
|
||||||
url.searchParams.append('noise', params.noise);
|
url.searchParams.append('noise', params.noise);
|
||||||
|
|
||||||
if (doesExist(upscale)) {
|
if (doesExist(upscale)) {
|
||||||
appendUpscaleToURL(url, upscale);
|
appendUpscaleToURL(url, upscale);
|
||||||
}
|
}
|
||||||
|
@ -289,12 +297,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
// eslint-disable-next-line no-return-await
|
// eslint-disable-next-line no-return-await
|
||||||
return await pending;
|
return await pending;
|
||||||
},
|
},
|
||||||
async outpaint(params: OutpaintParams, upscale?: UpscaleParams) {
|
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams) {
|
||||||
if (doesExist(pending)) {
|
if (doesExist(pending)) {
|
||||||
return pending;
|
return pending;
|
||||||
}
|
}
|
||||||
|
|
||||||
const url = makeImageURL(root, 'inpaint', params);
|
const url = makeImageURL(root, 'inpaint', params);
|
||||||
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
url.searchParams.append('filter', params.filter);
|
url.searchParams.append('filter', params.filter);
|
||||||
url.searchParams.append('noise', params.noise);
|
url.searchParams.append('noise', params.noise);
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,9 @@ export function ImageControl(props: ImageControlProps) {
|
||||||
id='schedulers'
|
id='schedulers'
|
||||||
labels={SCHEDULER_LABELS}
|
labels={SCHEDULER_LABELS}
|
||||||
name='Scheduler'
|
name='Scheduler'
|
||||||
result={schedulers}
|
query={{
|
||||||
|
result: schedulers,
|
||||||
|
}}
|
||||||
value={mustDefault(params.scheduler, '')}
|
value={mustDefault(params.scheduler, '')}
|
||||||
onChange={(value) => {
|
onChange={(value) => {
|
||||||
if (doesExist(props.onChange)) {
|
if (doesExist(props.onChange)) {
|
||||||
|
|
|
@ -4,8 +4,8 @@ import * as React from 'react';
|
||||||
import { useMutation, useQueryClient } from 'react-query';
|
import { useMutation, useQueryClient } from 'react-query';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ConfigParams, IMAGE_FILTER } from '../config.js';
|
import { IMAGE_FILTER } from '../config.js';
|
||||||
import { ClientContext, StateContext } from '../state.js';
|
import { ClientContext, ConfigContext, StateContext } from '../state.js';
|
||||||
import { ImageControl } from './ImageControl.js';
|
import { ImageControl } from './ImageControl.js';
|
||||||
import { ImageInput } from './ImageInput.js';
|
import { ImageInput } from './ImageInput.js';
|
||||||
import { NumericField } from './NumericField.js';
|
import { NumericField } from './NumericField.js';
|
||||||
|
@ -13,23 +13,14 @@ import { UpscaleControl } from './UpscaleControl.js';
|
||||||
|
|
||||||
const { useContext } = React;
|
const { useContext } = React;
|
||||||
|
|
||||||
export interface Img2ImgProps {
|
export function Img2Img() {
|
||||||
config: ConfigParams;
|
const config = mustExist(useContext(ConfigContext));
|
||||||
|
|
||||||
model: string;
|
|
||||||
platform: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function Img2Img(props: Img2ImgProps) {
|
|
||||||
const { config, model, platform } = props;
|
|
||||||
|
|
||||||
async function uploadSource() {
|
async function uploadSource() {
|
||||||
const { img2img, upscale } = state.getState();
|
const { model, img2img, upscale } = state.getState();
|
||||||
|
|
||||||
const output = await client.img2img({
|
const output = await client.img2img(model, {
|
||||||
...img2img,
|
...img2img,
|
||||||
model,
|
|
||||||
platform,
|
|
||||||
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
|
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
|
||||||
}, upscale);
|
}, upscale);
|
||||||
|
|
||||||
|
|
|
@ -4,8 +4,8 @@ import * as React from 'react';
|
||||||
import { useMutation, useQuery, useQueryClient } from 'react-query';
|
import { useMutation, useQuery, useQueryClient } from 'react-query';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ConfigParams, IMAGE_FILTER, STALE_TIME } from '../config.js';
|
import { IMAGE_FILTER, STALE_TIME } from '../config.js';
|
||||||
import { ClientContext, StateContext } from '../state.js';
|
import { ClientContext, ConfigContext, StateContext } from '../state.js';
|
||||||
import { MASK_LABELS, NOISE_LABELS } from '../strings.js';
|
import { MASK_LABELS, NOISE_LABELS } from '../strings.js';
|
||||||
import { ImageControl } from './ImageControl.js';
|
import { ImageControl } from './ImageControl.js';
|
||||||
import { ImageInput } from './ImageInput.js';
|
import { ImageInput } from './ImageInput.js';
|
||||||
|
@ -16,16 +16,10 @@ import { UpscaleControl } from './UpscaleControl.js';
|
||||||
|
|
||||||
const { useContext } = React;
|
const { useContext } = React;
|
||||||
|
|
||||||
export interface InpaintProps {
|
export function Inpaint() {
|
||||||
config: ConfigParams;
|
const config = mustExist(useContext(ConfigContext));
|
||||||
|
|
||||||
model: string;
|
|
||||||
platform: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function Inpaint(props: InpaintProps) {
|
|
||||||
const { config, model, platform } = props;
|
|
||||||
const client = mustExist(useContext(ClientContext));
|
const client = mustExist(useContext(ClientContext));
|
||||||
|
|
||||||
const masks = useQuery('masks', async () => client.masks(), {
|
const masks = useQuery('masks', async () => client.masks(), {
|
||||||
staleTime: STALE_TIME,
|
staleTime: STALE_TIME,
|
||||||
});
|
});
|
||||||
|
@ -35,24 +29,20 @@ export function Inpaint(props: InpaintProps) {
|
||||||
|
|
||||||
async function uploadSource(): Promise<void> {
|
async function uploadSource(): Promise<void> {
|
||||||
// these are not watched by the component, only sent by the mutation
|
// these are not watched by the component, only sent by the mutation
|
||||||
const { inpaint, outpaint, upscale } = state.getState();
|
const { model, inpaint, outpaint, upscale } = state.getState();
|
||||||
|
|
||||||
if (outpaint.enabled) {
|
if (outpaint.enabled) {
|
||||||
const output = await client.outpaint({
|
const output = await client.outpaint(model, {
|
||||||
...inpaint,
|
...inpaint,
|
||||||
...outpaint,
|
...outpaint,
|
||||||
model,
|
|
||||||
platform,
|
|
||||||
mask: mustExist(mask),
|
mask: mustExist(mask),
|
||||||
source: mustExist(source),
|
source: mustExist(source),
|
||||||
}, upscale);
|
}, upscale);
|
||||||
|
|
||||||
setLoading(output);
|
setLoading(output);
|
||||||
} else {
|
} else {
|
||||||
const output = await client.inpaint({
|
const output = await client.inpaint(model, {
|
||||||
...inpaint,
|
...inpaint,
|
||||||
model,
|
|
||||||
platform,
|
|
||||||
mask: mustExist(mask),
|
mask: mustExist(mask),
|
||||||
source: mustExist(source),
|
source: mustExist(source),
|
||||||
}, upscale);
|
}, upscale);
|
||||||
|
@ -122,7 +112,9 @@ export function Inpaint(props: InpaintProps) {
|
||||||
id='masks'
|
id='masks'
|
||||||
labels={MASK_LABELS}
|
labels={MASK_LABELS}
|
||||||
name='Mask Filter'
|
name='Mask Filter'
|
||||||
result={masks}
|
query={{
|
||||||
|
result: masks,
|
||||||
|
}}
|
||||||
value={filter}
|
value={filter}
|
||||||
onChange={(newFilter) => {
|
onChange={(newFilter) => {
|
||||||
setInpaint({
|
setInpaint({
|
||||||
|
@ -134,7 +126,9 @@ export function Inpaint(props: InpaintProps) {
|
||||||
id='noises'
|
id='noises'
|
||||||
labels={NOISE_LABELS}
|
labels={NOISE_LABELS}
|
||||||
name='Noise Source'
|
name='Noise Source'
|
||||||
result={noises}
|
query={{
|
||||||
|
result: noises,
|
||||||
|
}}
|
||||||
value={noise}
|
value={noise}
|
||||||
onChange={(newNoise) => {
|
onChange={(newNoise) => {
|
||||||
setInpaint({
|
setInpaint({
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
import { mustExist } from '@apextoaster/js-utils';
|
||||||
|
import { Stack } from '@mui/material';
|
||||||
|
import * as React from 'react';
|
||||||
|
import { useContext } from 'react';
|
||||||
|
import { useQuery } from 'react-query';
|
||||||
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
|
import { STALE_TIME } from '../config.js';
|
||||||
|
import { ClientContext, StateContext } from '../state.js';
|
||||||
|
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
|
||||||
|
import { QueryList } from './QueryList.js';
|
||||||
|
|
||||||
|
export function ModelControl() {
|
||||||
|
const client = mustExist(useContext(ClientContext));
|
||||||
|
const state = mustExist(useContext(StateContext));
|
||||||
|
const params = useStore(state, (s) => s.model);
|
||||||
|
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||||
|
const setModel = useStore(state, (s) => s.setModel);
|
||||||
|
|
||||||
|
const models = useQuery('models', async () => client.models(), {
|
||||||
|
staleTime: STALE_TIME,
|
||||||
|
});
|
||||||
|
const platforms = useQuery('platforms', async () => client.platforms(), {
|
||||||
|
staleTime: STALE_TIME,
|
||||||
|
});
|
||||||
|
|
||||||
|
return <Stack direction='row' spacing={2}>
|
||||||
|
<QueryList
|
||||||
|
id='platforms'
|
||||||
|
labels={PLATFORM_LABELS}
|
||||||
|
name='Platform'
|
||||||
|
query={{
|
||||||
|
result: platforms,
|
||||||
|
}}
|
||||||
|
value={params.platform}
|
||||||
|
onChange={(platform) => {
|
||||||
|
setModel({
|
||||||
|
platform,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<QueryList
|
||||||
|
id='diffusion'
|
||||||
|
labels={MODEL_LABELS}
|
||||||
|
name='Diffusion Model'
|
||||||
|
query={{
|
||||||
|
result: models,
|
||||||
|
selector: (result) => result.diffusion,
|
||||||
|
}}
|
||||||
|
value={params.model}
|
||||||
|
onChange={(model) => {
|
||||||
|
setModel({
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<QueryList
|
||||||
|
id='upscaling'
|
||||||
|
labels={MODEL_LABELS}
|
||||||
|
name='Upscaling Model'
|
||||||
|
query={{
|
||||||
|
result: models,
|
||||||
|
selector: (result) => result.upscaling,
|
||||||
|
}}
|
||||||
|
value={params.model}
|
||||||
|
onChange={(model) => {
|
||||||
|
setModel({
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<QueryList
|
||||||
|
id='correction'
|
||||||
|
labels={MODEL_LABELS}
|
||||||
|
name='Correction Model'
|
||||||
|
query={{
|
||||||
|
result: models,
|
||||||
|
selector: (result) => result.correction,
|
||||||
|
}}
|
||||||
|
value={params.model}
|
||||||
|
onChange={(model) => {
|
||||||
|
setModel({
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
</Stack>;
|
||||||
|
}
|
|
@ -1,41 +1,18 @@
|
||||||
import { mustExist } from '@apextoaster/js-utils';
|
|
||||||
import { TabContext, TabList, TabPanel } from '@mui/lab';
|
import { TabContext, TabList, TabPanel } from '@mui/lab';
|
||||||
import { Box, Container, Divider, Link, Stack, Tab, Typography } from '@mui/material';
|
import { Box, Container, Divider, Link, Tab, Typography } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { useQuery } from 'react-query';
|
|
||||||
|
|
||||||
import { ApiClient } from '../client.js';
|
|
||||||
import { ConfigParams, STALE_TIME } from '../config.js';
|
|
||||||
import { ClientContext } from '../state.js';
|
|
||||||
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
|
|
||||||
import { ImageHistory } from './ImageHistory.js';
|
import { ImageHistory } from './ImageHistory.js';
|
||||||
import { Img2Img } from './Img2Img.js';
|
import { Img2Img } from './Img2Img.js';
|
||||||
import { Inpaint } from './Inpaint.js';
|
import { Inpaint } from './Inpaint.js';
|
||||||
import { QueryList } from './QueryList.js';
|
import { ModelControl } from './ModelControl.js';
|
||||||
import { Settings } from './Settings.js';
|
import { Settings } from './Settings.js';
|
||||||
import { Txt2Img } from './Txt2Img.js';
|
import { Txt2Img } from './Txt2Img.js';
|
||||||
|
|
||||||
const { useContext, useState } = React;
|
const { useState } = React;
|
||||||
|
|
||||||
export interface OnnxWebProps {
|
export function OnnxWeb() {
|
||||||
client: ApiClient;
|
|
||||||
config: ConfigParams;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function OnnxWeb(props: OnnxWebProps) {
|
|
||||||
const { config } = props;
|
|
||||||
|
|
||||||
const client = mustExist(useContext(ClientContext));
|
|
||||||
const [tab, setTab] = useState('txt2img');
|
const [tab, setTab] = useState('txt2img');
|
||||||
const [model, setModel] = useState(config.model.default);
|
|
||||||
const [platform, setPlatform] = useState(config.platform.default);
|
|
||||||
|
|
||||||
const models = useQuery('models', async () => client.models(), {
|
|
||||||
staleTime: STALE_TIME,
|
|
||||||
});
|
|
||||||
const platforms = useQuery('platforms', async () => client.platforms(), {
|
|
||||||
staleTime: STALE_TIME,
|
|
||||||
});
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Container>
|
<Container>
|
||||||
|
@ -45,28 +22,7 @@ export function OnnxWeb(props: OnnxWebProps) {
|
||||||
</Typography>
|
</Typography>
|
||||||
</Box>
|
</Box>
|
||||||
<Box sx={{ mx: 4, my: 4 }}>
|
<Box sx={{ mx: 4, my: 4 }}>
|
||||||
<Stack direction='row' spacing={2}>
|
<ModelControl />
|
||||||
<QueryList
|
|
||||||
id='models'
|
|
||||||
labels={MODEL_LABELS}
|
|
||||||
name='Model'
|
|
||||||
result={models}
|
|
||||||
value={model}
|
|
||||||
onChange={(value) => {
|
|
||||||
setModel(value);
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
<QueryList
|
|
||||||
id='platforms'
|
|
||||||
labels={PLATFORM_LABELS}
|
|
||||||
name='Platform'
|
|
||||||
result={platforms}
|
|
||||||
value={platform}
|
|
||||||
onChange={(value) => {
|
|
||||||
setPlatform(value);
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</Stack>
|
|
||||||
</Box>
|
</Box>
|
||||||
<TabContext value={tab}>
|
<TabContext value={tab}>
|
||||||
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
|
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
|
||||||
|
@ -80,16 +36,16 @@ export function OnnxWeb(props: OnnxWebProps) {
|
||||||
</TabList>
|
</TabList>
|
||||||
</Box>
|
</Box>
|
||||||
<TabPanel value='txt2img'>
|
<TabPanel value='txt2img'>
|
||||||
<Txt2Img config={config} model={model} platform={platform} />
|
<Txt2Img />
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
<TabPanel value='img2img'>
|
<TabPanel value='img2img'>
|
||||||
<Img2Img config={config} model={model} platform={platform} />
|
<Img2Img />
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
<TabPanel value='inpaint'>
|
<TabPanel value='inpaint'>
|
||||||
<Inpaint config={config} model={model} platform={platform} />
|
<Inpaint />
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
<TabPanel value='settings'>
|
<TabPanel value='settings'>
|
||||||
<Settings config={config} />
|
<Settings />
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
</TabContext>
|
</TabContext>
|
||||||
<Divider variant='middle' />
|
<Divider variant='middle' />
|
||||||
|
|
|
@ -3,18 +3,42 @@ import { FormControl, InputLabel, MenuItem, Select } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { UseQueryResult } from 'react-query';
|
import { UseQueryResult } from 'react-query';
|
||||||
|
|
||||||
export interface QueryListProps {
|
export interface QueryListComplete {
|
||||||
|
result: UseQueryResult<Array<string>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface QueryListFilter<T> {
|
||||||
|
result: UseQueryResult<T>;
|
||||||
|
selector: (result: T) => Array<string>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface QueryListProps<T> {
|
||||||
id: string;
|
id: string;
|
||||||
labels: Record<string, string>;
|
labels: Record<string, string>;
|
||||||
name: string;
|
name: string;
|
||||||
result: UseQueryResult<Array<string>>;
|
|
||||||
value: string;
|
value: string;
|
||||||
|
|
||||||
|
query: QueryListComplete | QueryListFilter<T>;
|
||||||
|
|
||||||
onChange?: (value: string) => void;
|
onChange?: (value: string) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function QueryList(props: QueryListProps) {
|
export function hasFilter<T>(query: QueryListComplete | QueryListFilter<T>): query is QueryListFilter<T> {
|
||||||
const { labels, result, value } = props;
|
return Reflect.has(query, 'selector');
|
||||||
|
}
|
||||||
|
|
||||||
|
export function filterQuery<T>(query: QueryListComplete | QueryListFilter<T>): Array<string> {
|
||||||
|
if (hasFilter(query)) {
|
||||||
|
const data = mustExist(query.result.data);
|
||||||
|
return (query as QueryListFilter<unknown>).selector(data);
|
||||||
|
} else {
|
||||||
|
return mustExist(query.result.data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function QueryList<T>(props: QueryListProps<T>) {
|
||||||
|
const { labels, query, value } = props;
|
||||||
|
const { result } = query;
|
||||||
|
|
||||||
if (result.status === 'error') {
|
if (result.status === 'error') {
|
||||||
if (result.error instanceof Error) {
|
if (result.error instanceof Error) {
|
||||||
|
@ -34,7 +58,8 @@ export function QueryList(props: QueryListProps) {
|
||||||
|
|
||||||
// else: success
|
// else: success
|
||||||
const labelID = `query-list-${props.id}-labels`;
|
const labelID = `query-list-${props.id}-labels`;
|
||||||
const data = mustExist(result.data);
|
const data = filterQuery(query);
|
||||||
|
|
||||||
return <FormControl>
|
return <FormControl>
|
||||||
<InputLabel id={labelID}>{props.name}</InputLabel>
|
<InputLabel id={labelID}>{props.name}</InputLabel>
|
||||||
<Select
|
<Select
|
||||||
|
|
|
@ -1,19 +1,13 @@
|
||||||
import { mustExist } from '@apextoaster/js-utils';
|
import { mustExist } from '@apextoaster/js-utils';
|
||||||
import { Button, Stack, TextField } from '@mui/material';
|
import { Button, Stack, TextField } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
|
import { useContext } from 'react';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ConfigParams } from '../config.js';
|
|
||||||
import { StateContext } from '../state.js';
|
import { StateContext } from '../state.js';
|
||||||
import { NumericField } from './NumericField.js';
|
import { NumericField } from './NumericField.js';
|
||||||
|
|
||||||
const { useContext } = React;
|
export function Settings() {
|
||||||
|
|
||||||
export interface SettingsProps {
|
|
||||||
config: ConfigParams;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function Settings(_props: SettingsProps) {
|
|
||||||
const state = useStore(mustExist(useContext(StateContext)));
|
const state = useStore(mustExist(useContext(StateContext)));
|
||||||
|
|
||||||
return <Stack spacing={2}>
|
return <Stack spacing={2}>
|
||||||
|
@ -25,16 +19,6 @@ export function Settings(_props: SettingsProps) {
|
||||||
value={state.limit}
|
value={state.limit}
|
||||||
onChange={(value) => state.setLimit(value)}
|
onChange={(value) => state.setLimit(value)}
|
||||||
/>
|
/>
|
||||||
<TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => {
|
|
||||||
state.setDefaults({
|
|
||||||
model: event.target.value,
|
|
||||||
});
|
|
||||||
}} />
|
|
||||||
<TextField variant='outlined' label='Default Platform' value={state.defaults.platform} onChange={(event) => {
|
|
||||||
state.setDefaults({
|
|
||||||
platform: event.target.value,
|
|
||||||
});
|
|
||||||
}} />
|
|
||||||
<TextField variant='outlined' label='Default Prompt' value={state.defaults.prompt} onChange={(event) => {
|
<TextField variant='outlined' label='Default Prompt' value={state.defaults.prompt} onChange={(event) => {
|
||||||
state.setDefaults({
|
state.setDefaults({
|
||||||
prompt: event.target.value,
|
prompt: event.target.value,
|
||||||
|
|
|
@ -4,31 +4,19 @@ import * as React from 'react';
|
||||||
import { useMutation, useQueryClient } from 'react-query';
|
import { useMutation, useQueryClient } from 'react-query';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ConfigParams } from '../config.js';
|
import { ClientContext, ConfigContext, StateContext } from '../state.js';
|
||||||
import { ClientContext, StateContext } from '../state.js';
|
|
||||||
import { ImageControl } from './ImageControl.js';
|
import { ImageControl } from './ImageControl.js';
|
||||||
import { NumericField } from './NumericField.js';
|
import { NumericField } from './NumericField.js';
|
||||||
import { UpscaleControl } from './UpscaleControl.js';
|
import { UpscaleControl } from './UpscaleControl.js';
|
||||||
|
|
||||||
const { useContext } = React;
|
const { useContext } = React;
|
||||||
|
|
||||||
export interface Txt2ImgProps {
|
export function Txt2Img() {
|
||||||
config: ConfigParams;
|
const config = mustExist(useContext(ConfigContext));
|
||||||
|
|
||||||
model: string;
|
|
||||||
platform: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function Txt2Img(props: Txt2ImgProps) {
|
|
||||||
const { config, model, platform } = props;
|
|
||||||
|
|
||||||
async function generateImage() {
|
async function generateImage() {
|
||||||
const { txt2img, upscale } = state.getState();
|
const { model, txt2img, upscale } = state.getState();
|
||||||
const output = await client.txt2img({
|
const output = await client.txt2img(model, txt2img, upscale);
|
||||||
...txt2img,
|
|
||||||
model,
|
|
||||||
platform,
|
|
||||||
}, upscale);
|
|
||||||
|
|
||||||
setLoading(output);
|
setLoading(output);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import { Maybe } from '@apextoaster/js-utils';
|
import { Maybe } from '@apextoaster/js-utils';
|
||||||
|
|
||||||
import { Img2ImgParams, InpaintParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client.js';
|
import { Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client.js';
|
||||||
|
|
||||||
export interface ConfigNumber {
|
export interface ConfigNumber {
|
||||||
default: number;
|
default: number;
|
||||||
|
@ -30,7 +30,16 @@ export type ConfigState<T extends object, TValid = number | string> = {
|
||||||
[K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never;
|
[K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams & InpaintParams & OutpaintParams & UpscaleParams>>;
|
/* eslint-disable */
|
||||||
|
export type ConfigParams = ConfigRanges<Required<
|
||||||
|
Img2ImgParams &
|
||||||
|
Txt2ImgParams &
|
||||||
|
InpaintParams &
|
||||||
|
ModelParams &
|
||||||
|
OutpaintParams &
|
||||||
|
UpscaleParams
|
||||||
|
>>;
|
||||||
|
/* eslint-enable */
|
||||||
|
|
||||||
export interface Config {
|
export interface Config {
|
||||||
api: {
|
api: {
|
||||||
|
|
|
@ -11,7 +11,7 @@ import { makeClient } from './client.js';
|
||||||
import { OnnxError } from './components/OnnxError.js';
|
import { OnnxError } from './components/OnnxError.js';
|
||||||
import { OnnxWeb } from './components/OnnxWeb.js';
|
import { OnnxWeb } from './components/OnnxWeb.js';
|
||||||
import { Config, loadConfig } from './config.js';
|
import { Config, loadConfig } from './config.js';
|
||||||
import { ClientContext, createStateSlices, OnnxState, StateContext } from './state.js';
|
import { ClientContext, ConfigContext, createStateSlices, OnnxState, StateContext } from './state.js';
|
||||||
|
|
||||||
export function getApiRoot(config: Config): string {
|
export function getApiRoot(config: Config): string {
|
||||||
const query = new URLSearchParams(window.location.search);
|
const query = new URLSearchParams(window.location.search);
|
||||||
|
@ -48,6 +48,7 @@ export async function main() {
|
||||||
createHistorySlice,
|
createHistorySlice,
|
||||||
createImg2ImgSlice,
|
createImg2ImgSlice,
|
||||||
createInpaintSlice,
|
createInpaintSlice,
|
||||||
|
createModelSlice,
|
||||||
createOutpaintSlice,
|
createOutpaintSlice,
|
||||||
createTxt2ImgSlice,
|
createTxt2ImgSlice,
|
||||||
createUpscaleSlice,
|
createUpscaleSlice,
|
||||||
|
@ -58,6 +59,7 @@ export async function main() {
|
||||||
...createHistorySlice(...slice),
|
...createHistorySlice(...slice),
|
||||||
...createImg2ImgSlice(...slice),
|
...createImg2ImgSlice(...slice),
|
||||||
...createInpaintSlice(...slice),
|
...createInpaintSlice(...slice),
|
||||||
|
...createModelSlice(...slice),
|
||||||
...createTxt2ImgSlice(...slice),
|
...createTxt2ImgSlice(...slice),
|
||||||
...createOutpaintSlice(...slice),
|
...createOutpaintSlice(...slice),
|
||||||
...createUpscaleSlice(...slice),
|
...createUpscaleSlice(...slice),
|
||||||
|
@ -87,9 +89,11 @@ export async function main() {
|
||||||
// go
|
// go
|
||||||
app.render(<QueryClientProvider client={query}>
|
app.render(<QueryClientProvider client={query}>
|
||||||
<ClientContext.Provider value={client}>
|
<ClientContext.Provider value={client}>
|
||||||
<StateContext.Provider value={state}>
|
<ConfigContext.Provider value={params}>
|
||||||
<OnnxWeb client={client} config={params} />
|
<StateContext.Provider value={state}>
|
||||||
</StateContext.Provider>
|
<OnnxWeb />
|
||||||
|
</StateContext.Provider>
|
||||||
|
</ConfigContext.Provider>
|
||||||
</ClientContext.Provider>
|
</ClientContext.Provider>
|
||||||
</QueryClientProvider>);
|
</QueryClientProvider>);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|
|
@ -10,6 +10,7 @@ import {
|
||||||
BrushParams,
|
BrushParams,
|
||||||
Img2ImgParams,
|
Img2ImgParams,
|
||||||
InpaintParams,
|
InpaintParams,
|
||||||
|
ModelParams,
|
||||||
OutpaintPixels,
|
OutpaintPixels,
|
||||||
paramsFromConfig,
|
paramsFromConfig,
|
||||||
Txt2ImgParams,
|
Txt2ImgParams,
|
||||||
|
@ -75,12 +76,19 @@ interface UpscaleSlice {
|
||||||
setUpscale(upscale: Partial<UpscaleParams>): void;
|
setUpscale(upscale: Partial<UpscaleParams>): void;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface ModelSlice {
|
||||||
|
model: ModelParams;
|
||||||
|
|
||||||
|
setModel(model: Partial<ModelParams>): void;
|
||||||
|
}
|
||||||
|
|
||||||
export type OnnxState
|
export type OnnxState
|
||||||
= BrushSlice
|
= BrushSlice
|
||||||
& DefaultSlice
|
& DefaultSlice
|
||||||
& HistorySlice
|
& HistorySlice
|
||||||
& Img2ImgSlice
|
& Img2ImgSlice
|
||||||
& InpaintSlice
|
& InpaintSlice
|
||||||
|
& ModelSlice
|
||||||
& OutpaintSlice
|
& OutpaintSlice
|
||||||
& Txt2ImgSlice
|
& Txt2ImgSlice
|
||||||
& UpscaleSlice;
|
& UpscaleSlice;
|
||||||
|
@ -267,12 +275,30 @@ export function createStateSlices(base: ConfigParams) {
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const createModelSlice: StateCreator<OnnxState, [], [], ModelSlice> = (set) => ({
|
||||||
|
model: {
|
||||||
|
model: '',
|
||||||
|
platform: '',
|
||||||
|
upscaling: '',
|
||||||
|
correction: '',
|
||||||
|
},
|
||||||
|
setModel(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
model: {
|
||||||
|
...prev.model,
|
||||||
|
...params,
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
return {
|
return {
|
||||||
createBrushSlice,
|
createBrushSlice,
|
||||||
createDefaultSlice,
|
createDefaultSlice,
|
||||||
createHistorySlice,
|
createHistorySlice,
|
||||||
createImg2ImgSlice,
|
createImg2ImgSlice,
|
||||||
createInpaintSlice,
|
createInpaintSlice,
|
||||||
|
createModelSlice,
|
||||||
createOutpaintSlice,
|
createOutpaintSlice,
|
||||||
createTxt2ImgSlice,
|
createTxt2ImgSlice,
|
||||||
createUpscaleSlice,
|
createUpscaleSlice,
|
||||||
|
@ -280,4 +306,5 @@ export function createStateSlices(base: ConfigParams) {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
|
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
|
||||||
|
export const ConfigContext = createContext<Maybe<ConfigParams>>(undefined);
|
||||||
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
|
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
"ddpm",
|
"ddpm",
|
||||||
"denoise",
|
"denoise",
|
||||||
"directml",
|
"directml",
|
||||||
|
"ESRGAN",
|
||||||
"ftfy",
|
"ftfy",
|
||||||
"gfpgan",
|
"gfpgan",
|
||||||
"Heun",
|
"Heun",
|
||||||
|
@ -33,20 +34,24 @@
|
||||||
"numpy",
|
"numpy",
|
||||||
"Onnx",
|
"Onnx",
|
||||||
"onnxruntime",
|
"onnxruntime",
|
||||||
|
"opset",
|
||||||
"outpaint",
|
"outpaint",
|
||||||
"outscale",
|
"outscale",
|
||||||
"pndm",
|
"pndm",
|
||||||
"pretrained",
|
"pretrained",
|
||||||
"protobuf",
|
"protobuf",
|
||||||
"resrgan",
|
"resrgan",
|
||||||
|
"RRDB",
|
||||||
"runwayml",
|
"runwayml",
|
||||||
"scandir",
|
"scandir",
|
||||||
"scipy",
|
"scipy",
|
||||||
"Singlestep",
|
"Singlestep",
|
||||||
"spacy",
|
"spacy",
|
||||||
"spinalcase",
|
"spinalcase",
|
||||||
|
"stabilityai",
|
||||||
"stringcase",
|
"stringcase",
|
||||||
"upsampler",
|
"upsampler",
|
||||||
|
"upscaling",
|
||||||
"venv",
|
"venv",
|
||||||
"virtualenv",
|
"virtualenv",
|
||||||
"zustand"
|
"zustand"
|
||||||
|
|
Loading…
Reference in New Issue