1
0
Fork 0

feat(gui): add menus for upscaling and correction models

This commit is contained in:
Sean Sube 2023-01-16 20:11:10 -06:00
parent ee6308a091
commit 0080d86d91
13 changed files with 247 additions and 163 deletions

View File

@ -2,22 +2,23 @@ import { doesExist } from '@apextoaster/js-utils';
import { ConfigParams } from './config.js'; import { ConfigParams } from './config.js';
export interface BaseImgParams { export interface ModelParams {
/** /**
* Which ONNX model to use. * Which ONNX model to use.
*/ */
model?: string; model: string;
/** /**
* Hardware accelerator or CPU mode. * Hardware accelerator or CPU mode.
*/ */
platform?: string; platform: string;
/** upscaling: string;
* Scheduling algorithm. correction: string;
*/ }
scheduler?: string;
export interface BaseImgParams {
scheduler: string;
prompt: string; prompt: string;
negativePrompt?: string; negativePrompt?: string;
@ -90,18 +91,24 @@ export interface ApiReady {
ready: boolean; ready: boolean;
} }
export interface ApiModels {
diffusion: Array<string>;
correction: Array<string>;
upscaling: Array<string>;
}
export interface ApiClient { export interface ApiClient {
masks(): Promise<Array<string>>; masks(): Promise<Array<string>>;
models(): Promise<Array<string>>; models(): Promise<ApiModels>;
noises(): Promise<Array<string>>; noises(): Promise<Array<string>>;
params(): Promise<ConfigParams>; params(): Promise<ConfigParams>;
platforms(): Promise<Array<string>>; platforms(): Promise<Array<string>>;
schedulers(): Promise<Array<string>>; schedulers(): Promise<Array<string>>;
img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>; img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>; txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
inpaint(params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>; inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
outpaint(params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>; outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
ready(params: ApiResponse): Promise<ApiReady>; ready(params: ApiResponse): Promise<ApiReady>;
} }
@ -111,9 +118,7 @@ export const STATUS_SUCCESS = 200;
export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams> { export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams> {
return { return {
cfg: defaults.cfg.default, cfg: defaults.cfg.default,
model: defaults.model.default,
negativePrompt: defaults.negativePrompt.default, negativePrompt: defaults.negativePrompt.default,
platform: defaults.platform.default,
prompt: defaults.prompt.default, prompt: defaults.prompt.default,
scheduler: defaults.scheduler.default, scheduler: defaults.scheduler.default,
steps: defaults.steps.default, steps: defaults.steps.default,
@ -141,14 +146,6 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT)); url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER)); url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
if (doesExist(params.model)) {
url.searchParams.append('model', params.model);
}
if (doesExist(params.platform)) {
url.searchParams.append('platform', params.platform);
}
if (doesExist(params.scheduler)) { if (doesExist(params.scheduler)) {
url.searchParams.append('scheduler', params.scheduler); url.searchParams.append('scheduler', params.scheduler);
} }
@ -167,6 +164,11 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
return url; return url;
} }
export function appendModelToURL(url: URL, params: ModelParams) {
url.searchParams.append('model', params.model);
url.searchParams.append('platform', params.platform);
}
export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
if (upscale.enabled) { if (upscale.enabled) {
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT)); url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
@ -191,10 +193,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path); const res = await f(path);
return await res.json() as Array<string>; return await res.json() as Array<string>;
}, },
async models(): Promise<Array<string>> { async models(): Promise<ApiModels> {
const path = makeApiUrl(root, 'settings', 'models'); const path = makeApiUrl(root, 'settings', 'models');
const res = await f(path); const res = await f(path);
return await res.json() as Array<string>; return await res.json() as ApiModels;
}, },
async noises(): Promise<Array<string>> { async noises(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'noises'); const path = makeApiUrl(root, 'settings', 'noises');
@ -216,12 +218,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path); const res = await f(path);
return await res.json() as Array<string>; return await res.json() as Array<string>;
}, },
async img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> { async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
if (doesExist(pending)) { if (doesExist(pending)) {
return pending; return pending;
} }
const url = makeImageURL(root, 'img2img', params); const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
if (doesExist(upscale)) { if (doesExist(upscale)) {
@ -239,12 +243,13 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await // eslint-disable-next-line no-return-await
return await pending; return await pending;
}, },
async txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> { async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
if (doesExist(pending)) { if (doesExist(pending)) {
return pending; return pending;
} }
const url = makeImageURL(root, 'txt2img', params); const url = makeImageURL(root, 'txt2img', params);
appendModelToURL(url, model);
if (doesExist(params.width)) { if (doesExist(params.width)) {
url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER)); url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER));
@ -265,14 +270,17 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await // eslint-disable-next-line no-return-await
return await pending; return await pending;
}, },
async inpaint(params: InpaintParams, upscale?: UpscaleParams) { async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams) {
if (doesExist(pending)) { if (doesExist(pending)) {
return pending; return pending;
} }
const url = makeImageURL(root, 'inpaint', params); const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
url.searchParams.append('filter', params.filter); url.searchParams.append('filter', params.filter);
url.searchParams.append('noise', params.noise); url.searchParams.append('noise', params.noise);
if (doesExist(upscale)) { if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale); appendUpscaleToURL(url, upscale);
} }
@ -289,12 +297,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await // eslint-disable-next-line no-return-await
return await pending; return await pending;
}, },
async outpaint(params: OutpaintParams, upscale?: UpscaleParams) { async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams) {
if (doesExist(pending)) { if (doesExist(pending)) {
return pending; return pending;
} }
const url = makeImageURL(root, 'inpaint', params); const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
url.searchParams.append('filter', params.filter); url.searchParams.append('filter', params.filter);
url.searchParams.append('noise', params.noise); url.searchParams.append('noise', params.noise);

View File

@ -41,7 +41,9 @@ export function ImageControl(props: ImageControlProps) {
id='schedulers' id='schedulers'
labels={SCHEDULER_LABELS} labels={SCHEDULER_LABELS}
name='Scheduler' name='Scheduler'
result={schedulers} query={{
result: schedulers,
}}
value={mustDefault(params.scheduler, '')} value={mustDefault(params.scheduler, '')}
onChange={(value) => { onChange={(value) => {
if (doesExist(props.onChange)) { if (doesExist(props.onChange)) {

View File

@ -4,8 +4,8 @@ import * as React from 'react';
import { useMutation, useQueryClient } from 'react-query'; import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigParams, IMAGE_FILTER } from '../config.js'; import { IMAGE_FILTER } from '../config.js';
import { ClientContext, StateContext } from '../state.js'; import { ClientContext, ConfigContext, StateContext } from '../state.js';
import { ImageControl } from './ImageControl.js'; import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js'; import { ImageInput } from './ImageInput.js';
import { NumericField } from './NumericField.js'; import { NumericField } from './NumericField.js';
@ -13,23 +13,14 @@ import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React; const { useContext } = React;
export interface Img2ImgProps { export function Img2Img() {
config: ConfigParams; const config = mustExist(useContext(ConfigContext));
model: string;
platform: string;
}
export function Img2Img(props: Img2ImgProps) {
const { config, model, platform } = props;
async function uploadSource() { async function uploadSource() {
const { img2img, upscale } = state.getState(); const { model, img2img, upscale } = state.getState();
const output = await client.img2img({ const output = await client.img2img(model, {
...img2img, ...img2img,
model,
platform,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, upscale); }, upscale);

View File

@ -4,8 +4,8 @@ import * as React from 'react';
import { useMutation, useQuery, useQueryClient } from 'react-query'; import { useMutation, useQuery, useQueryClient } from 'react-query';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigParams, IMAGE_FILTER, STALE_TIME } from '../config.js'; import { IMAGE_FILTER, STALE_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js'; import { ClientContext, ConfigContext, StateContext } from '../state.js';
import { MASK_LABELS, NOISE_LABELS } from '../strings.js'; import { MASK_LABELS, NOISE_LABELS } from '../strings.js';
import { ImageControl } from './ImageControl.js'; import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js'; import { ImageInput } from './ImageInput.js';
@ -16,16 +16,10 @@ import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React; const { useContext } = React;
export interface InpaintProps { export function Inpaint() {
config: ConfigParams; const config = mustExist(useContext(ConfigContext));
model: string;
platform: string;
}
export function Inpaint(props: InpaintProps) {
const { config, model, platform } = props;
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));
const masks = useQuery('masks', async () => client.masks(), { const masks = useQuery('masks', async () => client.masks(), {
staleTime: STALE_TIME, staleTime: STALE_TIME,
}); });
@ -35,24 +29,20 @@ export function Inpaint(props: InpaintProps) {
async function uploadSource(): Promise<void> { async function uploadSource(): Promise<void> {
// these are not watched by the component, only sent by the mutation // these are not watched by the component, only sent by the mutation
const { inpaint, outpaint, upscale } = state.getState(); const { model, inpaint, outpaint, upscale } = state.getState();
if (outpaint.enabled) { if (outpaint.enabled) {
const output = await client.outpaint({ const output = await client.outpaint(model, {
...inpaint, ...inpaint,
...outpaint, ...outpaint,
model,
platform,
mask: mustExist(mask), mask: mustExist(mask),
source: mustExist(source), source: mustExist(source),
}, upscale); }, upscale);
setLoading(output); setLoading(output);
} else { } else {
const output = await client.inpaint({ const output = await client.inpaint(model, {
...inpaint, ...inpaint,
model,
platform,
mask: mustExist(mask), mask: mustExist(mask),
source: mustExist(source), source: mustExist(source),
}, upscale); }, upscale);
@ -122,7 +112,9 @@ export function Inpaint(props: InpaintProps) {
id='masks' id='masks'
labels={MASK_LABELS} labels={MASK_LABELS}
name='Mask Filter' name='Mask Filter'
result={masks} query={{
result: masks,
}}
value={filter} value={filter}
onChange={(newFilter) => { onChange={(newFilter) => {
setInpaint({ setInpaint({
@ -134,7 +126,9 @@ export function Inpaint(props: InpaintProps) {
id='noises' id='noises'
labels={NOISE_LABELS} labels={NOISE_LABELS}
name='Noise Source' name='Noise Source'
result={noises} query={{
result: noises,
}}
value={noise} value={noise}
onChange={(newNoise) => { onChange={(newNoise) => {
setInpaint({ setInpaint({

View File

@ -0,0 +1,89 @@
import { mustExist } from '@apextoaster/js-utils';
import { Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useQuery } from 'react-query';
import { useStore } from 'zustand';
import { STALE_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
import { QueryList } from './QueryList.js';
export function ModelControl() {
const client = mustExist(useContext(ClientContext));
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.model);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setModel = useStore(state, (s) => s.setModel);
const models = useQuery('models', async () => client.models(), {
staleTime: STALE_TIME,
});
const platforms = useQuery('platforms', async () => client.platforms(), {
staleTime: STALE_TIME,
});
return <Stack direction='row' spacing={2}>
<QueryList
id='platforms'
labels={PLATFORM_LABELS}
name='Platform'
query={{
result: platforms,
}}
value={params.platform}
onChange={(platform) => {
setModel({
platform,
});
}}
/>
<QueryList
id='diffusion'
labels={MODEL_LABELS}
name='Diffusion Model'
query={{
result: models,
selector: (result) => result.diffusion,
}}
value={params.model}
onChange={(model) => {
setModel({
model,
});
}}
/>
<QueryList
id='upscaling'
labels={MODEL_LABELS}
name='Upscaling Model'
query={{
result: models,
selector: (result) => result.upscaling,
}}
value={params.model}
onChange={(model) => {
setModel({
model,
});
}}
/>
<QueryList
id='correction'
labels={MODEL_LABELS}
name='Correction Model'
query={{
result: models,
selector: (result) => result.correction,
}}
value={params.model}
onChange={(model) => {
setModel({
model,
});
}}
/>
</Stack>;
}

View File

@ -1,41 +1,18 @@
import { mustExist } from '@apextoaster/js-utils';
import { TabContext, TabList, TabPanel } from '@mui/lab'; import { TabContext, TabList, TabPanel } from '@mui/lab';
import { Box, Container, Divider, Link, Stack, Tab, Typography } from '@mui/material'; import { Box, Container, Divider, Link, Tab, Typography } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useQuery } from 'react-query';
import { ApiClient } from '../client.js';
import { ConfigParams, STALE_TIME } from '../config.js';
import { ClientContext } from '../state.js';
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
import { ImageHistory } from './ImageHistory.js'; import { ImageHistory } from './ImageHistory.js';
import { Img2Img } from './Img2Img.js'; import { Img2Img } from './Img2Img.js';
import { Inpaint } from './Inpaint.js'; import { Inpaint } from './Inpaint.js';
import { QueryList } from './QueryList.js'; import { ModelControl } from './ModelControl.js';
import { Settings } from './Settings.js'; import { Settings } from './Settings.js';
import { Txt2Img } from './Txt2Img.js'; import { Txt2Img } from './Txt2Img.js';
const { useContext, useState } = React; const { useState } = React;
export interface OnnxWebProps { export function OnnxWeb() {
client: ApiClient;
config: ConfigParams;
}
export function OnnxWeb(props: OnnxWebProps) {
const { config } = props;
const client = mustExist(useContext(ClientContext));
const [tab, setTab] = useState('txt2img'); const [tab, setTab] = useState('txt2img');
const [model, setModel] = useState(config.model.default);
const [platform, setPlatform] = useState(config.platform.default);
const models = useQuery('models', async () => client.models(), {
staleTime: STALE_TIME,
});
const platforms = useQuery('platforms', async () => client.platforms(), {
staleTime: STALE_TIME,
});
return ( return (
<Container> <Container>
@ -45,28 +22,7 @@ export function OnnxWeb(props: OnnxWebProps) {
</Typography> </Typography>
</Box> </Box>
<Box sx={{ mx: 4, my: 4 }}> <Box sx={{ mx: 4, my: 4 }}>
<Stack direction='row' spacing={2}> <ModelControl />
<QueryList
id='models'
labels={MODEL_LABELS}
name='Model'
result={models}
value={model}
onChange={(value) => {
setModel(value);
}}
/>
<QueryList
id='platforms'
labels={PLATFORM_LABELS}
name='Platform'
result={platforms}
value={platform}
onChange={(value) => {
setPlatform(value);
}}
/>
</Stack>
</Box> </Box>
<TabContext value={tab}> <TabContext value={tab}>
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}> <Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
@ -80,16 +36,16 @@ export function OnnxWeb(props: OnnxWebProps) {
</TabList> </TabList>
</Box> </Box>
<TabPanel value='txt2img'> <TabPanel value='txt2img'>
<Txt2Img config={config} model={model} platform={platform} /> <Txt2Img />
</TabPanel> </TabPanel>
<TabPanel value='img2img'> <TabPanel value='img2img'>
<Img2Img config={config} model={model} platform={platform} /> <Img2Img />
</TabPanel> </TabPanel>
<TabPanel value='inpaint'> <TabPanel value='inpaint'>
<Inpaint config={config} model={model} platform={platform} /> <Inpaint />
</TabPanel> </TabPanel>
<TabPanel value='settings'> <TabPanel value='settings'>
<Settings config={config} /> <Settings />
</TabPanel> </TabPanel>
</TabContext> </TabContext>
<Divider variant='middle' /> <Divider variant='middle' />

View File

@ -3,18 +3,42 @@ import { FormControl, InputLabel, MenuItem, Select } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { UseQueryResult } from 'react-query'; import { UseQueryResult } from 'react-query';
export interface QueryListProps { export interface QueryListComplete {
result: UseQueryResult<Array<string>>;
}
export interface QueryListFilter<T> {
result: UseQueryResult<T>;
selector: (result: T) => Array<string>;
}
export interface QueryListProps<T> {
id: string; id: string;
labels: Record<string, string>; labels: Record<string, string>;
name: string; name: string;
result: UseQueryResult<Array<string>>;
value: string; value: string;
query: QueryListComplete | QueryListFilter<T>;
onChange?: (value: string) => void; onChange?: (value: string) => void;
} }
export function QueryList(props: QueryListProps) { export function hasFilter<T>(query: QueryListComplete | QueryListFilter<T>): query is QueryListFilter<T> {
const { labels, result, value } = props; return Reflect.has(query, 'selector');
}
export function filterQuery<T>(query: QueryListComplete | QueryListFilter<T>): Array<string> {
if (hasFilter(query)) {
const data = mustExist(query.result.data);
return (query as QueryListFilter<unknown>).selector(data);
} else {
return mustExist(query.result.data);
}
}
export function QueryList<T>(props: QueryListProps<T>) {
const { labels, query, value } = props;
const { result } = query;
if (result.status === 'error') { if (result.status === 'error') {
if (result.error instanceof Error) { if (result.error instanceof Error) {
@ -34,7 +58,8 @@ export function QueryList(props: QueryListProps) {
// else: success // else: success
const labelID = `query-list-${props.id}-labels`; const labelID = `query-list-${props.id}-labels`;
const data = mustExist(result.data); const data = filterQuery(query);
return <FormControl> return <FormControl>
<InputLabel id={labelID}>{props.name}</InputLabel> <InputLabel id={labelID}>{props.name}</InputLabel>
<Select <Select

View File

@ -1,19 +1,13 @@
import { mustExist } from '@apextoaster/js-utils'; import { mustExist } from '@apextoaster/js-utils';
import { Button, Stack, TextField } from '@mui/material'; import { Button, Stack, TextField } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useContext } from 'react';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigParams } from '../config.js';
import { StateContext } from '../state.js'; import { StateContext } from '../state.js';
import { NumericField } from './NumericField.js'; import { NumericField } from './NumericField.js';
const { useContext } = React; export function Settings() {
export interface SettingsProps {
config: ConfigParams;
}
export function Settings(_props: SettingsProps) {
const state = useStore(mustExist(useContext(StateContext))); const state = useStore(mustExist(useContext(StateContext)));
return <Stack spacing={2}> return <Stack spacing={2}>
@ -25,16 +19,6 @@ export function Settings(_props: SettingsProps) {
value={state.limit} value={state.limit}
onChange={(value) => state.setLimit(value)} onChange={(value) => state.setLimit(value)}
/> />
<TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => {
state.setDefaults({
model: event.target.value,
});
}} />
<TextField variant='outlined' label='Default Platform' value={state.defaults.platform} onChange={(event) => {
state.setDefaults({
platform: event.target.value,
});
}} />
<TextField variant='outlined' label='Default Prompt' value={state.defaults.prompt} onChange={(event) => { <TextField variant='outlined' label='Default Prompt' value={state.defaults.prompt} onChange={(event) => {
state.setDefaults({ state.setDefaults({
prompt: event.target.value, prompt: event.target.value,

View File

@ -4,31 +4,19 @@ import * as React from 'react';
import { useMutation, useQueryClient } from 'react-query'; import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigParams } from '../config.js'; import { ClientContext, ConfigContext, StateContext } from '../state.js';
import { ClientContext, StateContext } from '../state.js';
import { ImageControl } from './ImageControl.js'; import { ImageControl } from './ImageControl.js';
import { NumericField } from './NumericField.js'; import { NumericField } from './NumericField.js';
import { UpscaleControl } from './UpscaleControl.js'; import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React; const { useContext } = React;
export interface Txt2ImgProps { export function Txt2Img() {
config: ConfigParams; const config = mustExist(useContext(ConfigContext));
model: string;
platform: string;
}
export function Txt2Img(props: Txt2ImgProps) {
const { config, model, platform } = props;
async function generateImage() { async function generateImage() {
const { txt2img, upscale } = state.getState(); const { model, txt2img, upscale } = state.getState();
const output = await client.txt2img({ const output = await client.txt2img(model, txt2img, upscale);
...txt2img,
model,
platform,
}, upscale);
setLoading(output); setLoading(output);
} }

View File

@ -1,6 +1,6 @@
import { Maybe } from '@apextoaster/js-utils'; import { Maybe } from '@apextoaster/js-utils';
import { Img2ImgParams, InpaintParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client.js'; import { Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client.js';
export interface ConfigNumber { export interface ConfigNumber {
default: number; default: number;
@ -30,7 +30,16 @@ export type ConfigState<T extends object, TValid = number | string> = {
[K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never; [K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never;
}; };
export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams & InpaintParams & OutpaintParams & UpscaleParams>>; /* eslint-disable */
export type ConfigParams = ConfigRanges<Required<
Img2ImgParams &
Txt2ImgParams &
InpaintParams &
ModelParams &
OutpaintParams &
UpscaleParams
>>;
/* eslint-enable */
export interface Config { export interface Config {
api: { api: {

View File

@ -11,7 +11,7 @@ import { makeClient } from './client.js';
import { OnnxError } from './components/OnnxError.js'; import { OnnxError } from './components/OnnxError.js';
import { OnnxWeb } from './components/OnnxWeb.js'; import { OnnxWeb } from './components/OnnxWeb.js';
import { Config, loadConfig } from './config.js'; import { Config, loadConfig } from './config.js';
import { ClientContext, createStateSlices, OnnxState, StateContext } from './state.js'; import { ClientContext, ConfigContext, createStateSlices, OnnxState, StateContext } from './state.js';
export function getApiRoot(config: Config): string { export function getApiRoot(config: Config): string {
const query = new URLSearchParams(window.location.search); const query = new URLSearchParams(window.location.search);
@ -48,6 +48,7 @@ export async function main() {
createHistorySlice, createHistorySlice,
createImg2ImgSlice, createImg2ImgSlice,
createInpaintSlice, createInpaintSlice,
createModelSlice,
createOutpaintSlice, createOutpaintSlice,
createTxt2ImgSlice, createTxt2ImgSlice,
createUpscaleSlice, createUpscaleSlice,
@ -58,6 +59,7 @@ export async function main() {
...createHistorySlice(...slice), ...createHistorySlice(...slice),
...createImg2ImgSlice(...slice), ...createImg2ImgSlice(...slice),
...createInpaintSlice(...slice), ...createInpaintSlice(...slice),
...createModelSlice(...slice),
...createTxt2ImgSlice(...slice), ...createTxt2ImgSlice(...slice),
...createOutpaintSlice(...slice), ...createOutpaintSlice(...slice),
...createUpscaleSlice(...slice), ...createUpscaleSlice(...slice),
@ -87,9 +89,11 @@ export async function main() {
// go // go
app.render(<QueryClientProvider client={query}> app.render(<QueryClientProvider client={query}>
<ClientContext.Provider value={client}> <ClientContext.Provider value={client}>
<ConfigContext.Provider value={params}>
<StateContext.Provider value={state}> <StateContext.Provider value={state}>
<OnnxWeb client={client} config={params} /> <OnnxWeb />
</StateContext.Provider> </StateContext.Provider>
</ConfigContext.Provider>
</ClientContext.Provider> </ClientContext.Provider>
</QueryClientProvider>); </QueryClientProvider>);
} catch (err) { } catch (err) {

View File

@ -10,6 +10,7 @@ import {
BrushParams, BrushParams,
Img2ImgParams, Img2ImgParams,
InpaintParams, InpaintParams,
ModelParams,
OutpaintPixels, OutpaintPixels,
paramsFromConfig, paramsFromConfig,
Txt2ImgParams, Txt2ImgParams,
@ -75,12 +76,19 @@ interface UpscaleSlice {
setUpscale(upscale: Partial<UpscaleParams>): void; setUpscale(upscale: Partial<UpscaleParams>): void;
} }
interface ModelSlice {
model: ModelParams;
setModel(model: Partial<ModelParams>): void;
}
export type OnnxState export type OnnxState
= BrushSlice = BrushSlice
& DefaultSlice & DefaultSlice
& HistorySlice & HistorySlice
& Img2ImgSlice & Img2ImgSlice
& InpaintSlice & InpaintSlice
& ModelSlice
& OutpaintSlice & OutpaintSlice
& Txt2ImgSlice & Txt2ImgSlice
& UpscaleSlice; & UpscaleSlice;
@ -267,12 +275,30 @@ export function createStateSlices(base: ConfigParams) {
}, },
}); });
const createModelSlice: StateCreator<OnnxState, [], [], ModelSlice> = (set) => ({
model: {
model: '',
platform: '',
upscaling: '',
correction: '',
},
setModel(params) {
set((prev) => ({
model: {
...prev.model,
...params,
}
}));
},
});
return { return {
createBrushSlice, createBrushSlice,
createDefaultSlice, createDefaultSlice,
createHistorySlice, createHistorySlice,
createImg2ImgSlice, createImg2ImgSlice,
createInpaintSlice, createInpaintSlice,
createModelSlice,
createOutpaintSlice, createOutpaintSlice,
createTxt2ImgSlice, createTxt2ImgSlice,
createUpscaleSlice, createUpscaleSlice,
@ -280,4 +306,5 @@ export function createStateSlices(base: ConfigParams) {
} }
export const ClientContext = createContext<Maybe<ApiClient>>(undefined); export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
export const ConfigContext = createContext<Maybe<ConfigParams>>(undefined);
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined); export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);

View File

@ -19,6 +19,7 @@
"ddpm", "ddpm",
"denoise", "denoise",
"directml", "directml",
"ESRGAN",
"ftfy", "ftfy",
"gfpgan", "gfpgan",
"Heun", "Heun",
@ -33,20 +34,24 @@
"numpy", "numpy",
"Onnx", "Onnx",
"onnxruntime", "onnxruntime",
"opset",
"outpaint", "outpaint",
"outscale", "outscale",
"pndm", "pndm",
"pretrained", "pretrained",
"protobuf", "protobuf",
"resrgan", "resrgan",
"RRDB",
"runwayml", "runwayml",
"scandir", "scandir",
"scipy", "scipy",
"Singlestep", "Singlestep",
"spacy", "spacy",
"spinalcase", "spinalcase",
"stabilityai",
"stringcase", "stringcase",
"upsampler", "upsampler",
"upscaling",
"venv", "venv",
"virtualenv", "virtualenv",
"zustand" "zustand"