feat(gui): add menus for upscaling and correction models
This commit is contained in:
parent
ee6308a091
commit
0080d86d91
|
@ -2,22 +2,23 @@ import { doesExist } from '@apextoaster/js-utils';
|
|||
|
||||
import { ConfigParams } from './config.js';
|
||||
|
||||
export interface BaseImgParams {
|
||||
export interface ModelParams {
|
||||
/**
|
||||
* Which ONNX model to use.
|
||||
*/
|
||||
model?: string;
|
||||
model: string;
|
||||
|
||||
/**
|
||||
* Hardware accelerator or CPU mode.
|
||||
*/
|
||||
platform?: string;
|
||||
platform: string;
|
||||
|
||||
/**
|
||||
* Scheduling algorithm.
|
||||
*/
|
||||
scheduler?: string;
|
||||
upscaling: string;
|
||||
correction: string;
|
||||
}
|
||||
|
||||
export interface BaseImgParams {
|
||||
scheduler: string;
|
||||
prompt: string;
|
||||
negativePrompt?: string;
|
||||
|
||||
|
@ -90,18 +91,24 @@ export interface ApiReady {
|
|||
ready: boolean;
|
||||
}
|
||||
|
||||
export interface ApiModels {
|
||||
diffusion: Array<string>;
|
||||
correction: Array<string>;
|
||||
upscaling: Array<string>;
|
||||
}
|
||||
|
||||
export interface ApiClient {
|
||||
masks(): Promise<Array<string>>;
|
||||
models(): Promise<Array<string>>;
|
||||
models(): Promise<ApiModels>;
|
||||
noises(): Promise<Array<string>>;
|
||||
params(): Promise<ConfigParams>;
|
||||
platforms(): Promise<Array<string>>;
|
||||
schedulers(): Promise<Array<string>>;
|
||||
|
||||
img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
inpaint(params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
outpaint(params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
|
||||
ready(params: ApiResponse): Promise<ApiReady>;
|
||||
}
|
||||
|
@ -111,9 +118,7 @@ export const STATUS_SUCCESS = 200;
|
|||
export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams> {
|
||||
return {
|
||||
cfg: defaults.cfg.default,
|
||||
model: defaults.model.default,
|
||||
negativePrompt: defaults.negativePrompt.default,
|
||||
platform: defaults.platform.default,
|
||||
prompt: defaults.prompt.default,
|
||||
scheduler: defaults.scheduler.default,
|
||||
steps: defaults.steps.default,
|
||||
|
@ -141,14 +146,6 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
|
|||
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
|
||||
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
|
||||
|
||||
if (doesExist(params.model)) {
|
||||
url.searchParams.append('model', params.model);
|
||||
}
|
||||
|
||||
if (doesExist(params.platform)) {
|
||||
url.searchParams.append('platform', params.platform);
|
||||
}
|
||||
|
||||
if (doesExist(params.scheduler)) {
|
||||
url.searchParams.append('scheduler', params.scheduler);
|
||||
}
|
||||
|
@ -167,6 +164,11 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
|
|||
return url;
|
||||
}
|
||||
|
||||
export function appendModelToURL(url: URL, params: ModelParams) {
|
||||
url.searchParams.append('model', params.model);
|
||||
url.searchParams.append('platform', params.platform);
|
||||
}
|
||||
|
||||
export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
|
||||
if (upscale.enabled) {
|
||||
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
|
||||
|
@ -191,10 +193,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
},
|
||||
async models(): Promise<Array<string>> {
|
||||
async models(): Promise<ApiModels> {
|
||||
const path = makeApiUrl(root, 'settings', 'models');
|
||||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
return await res.json() as ApiModels;
|
||||
},
|
||||
async noises(): Promise<Array<string>> {
|
||||
const path = makeApiUrl(root, 'settings', 'noises');
|
||||
|
@ -216,12 +218,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
},
|
||||
async img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
||||
if (doesExist(pending)) {
|
||||
return pending;
|
||||
}
|
||||
|
||||
const url = makeImageURL(root, 'img2img', params);
|
||||
appendModelToURL(url, model);
|
||||
|
||||
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
|
||||
|
||||
if (doesExist(upscale)) {
|
||||
|
@ -239,12 +243,13 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
// eslint-disable-next-line no-return-await
|
||||
return await pending;
|
||||
},
|
||||
async txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
||||
if (doesExist(pending)) {
|
||||
return pending;
|
||||
}
|
||||
|
||||
const url = makeImageURL(root, 'txt2img', params);
|
||||
appendModelToURL(url, model);
|
||||
|
||||
if (doesExist(params.width)) {
|
||||
url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER));
|
||||
|
@ -265,14 +270,17 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
// eslint-disable-next-line no-return-await
|
||||
return await pending;
|
||||
},
|
||||
async inpaint(params: InpaintParams, upscale?: UpscaleParams) {
|
||||
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams) {
|
||||
if (doesExist(pending)) {
|
||||
return pending;
|
||||
}
|
||||
|
||||
const url = makeImageURL(root, 'inpaint', params);
|
||||
appendModelToURL(url, model);
|
||||
|
||||
url.searchParams.append('filter', params.filter);
|
||||
url.searchParams.append('noise', params.noise);
|
||||
|
||||
if (doesExist(upscale)) {
|
||||
appendUpscaleToURL(url, upscale);
|
||||
}
|
||||
|
@ -289,12 +297,14 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
// eslint-disable-next-line no-return-await
|
||||
return await pending;
|
||||
},
|
||||
async outpaint(params: OutpaintParams, upscale?: UpscaleParams) {
|
||||
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams) {
|
||||
if (doesExist(pending)) {
|
||||
return pending;
|
||||
}
|
||||
|
||||
const url = makeImageURL(root, 'inpaint', params);
|
||||
appendModelToURL(url, model);
|
||||
|
||||
url.searchParams.append('filter', params.filter);
|
||||
url.searchParams.append('noise', params.noise);
|
||||
|
||||
|
|
|
@ -41,7 +41,9 @@ export function ImageControl(props: ImageControlProps) {
|
|||
id='schedulers'
|
||||
labels={SCHEDULER_LABELS}
|
||||
name='Scheduler'
|
||||
result={schedulers}
|
||||
query={{
|
||||
result: schedulers,
|
||||
}}
|
||||
value={mustDefault(params.scheduler, '')}
|
||||
onChange={(value) => {
|
||||
if (doesExist(props.onChange)) {
|
||||
|
|
|
@ -4,8 +4,8 @@ import * as React from 'react';
|
|||
import { useMutation, useQueryClient } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigParams, IMAGE_FILTER } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../state.js';
|
||||
import { IMAGE_FILTER } from '../config.js';
|
||||
import { ClientContext, ConfigContext, StateContext } from '../state.js';
|
||||
import { ImageControl } from './ImageControl.js';
|
||||
import { ImageInput } from './ImageInput.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
|
@ -13,23 +13,14 @@ import { UpscaleControl } from './UpscaleControl.js';
|
|||
|
||||
const { useContext } = React;
|
||||
|
||||
export interface Img2ImgProps {
|
||||
config: ConfigParams;
|
||||
|
||||
model: string;
|
||||
platform: string;
|
||||
}
|
||||
|
||||
export function Img2Img(props: Img2ImgProps) {
|
||||
const { config, model, platform } = props;
|
||||
export function Img2Img() {
|
||||
const config = mustExist(useContext(ConfigContext));
|
||||
|
||||
async function uploadSource() {
|
||||
const { img2img, upscale } = state.getState();
|
||||
const { model, img2img, upscale } = state.getState();
|
||||
|
||||
const output = await client.img2img({
|
||||
const output = await client.img2img(model, {
|
||||
...img2img,
|
||||
model,
|
||||
platform,
|
||||
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
|
||||
}, upscale);
|
||||
|
||||
|
|
|
@ -4,8 +4,8 @@ import * as React from 'react';
|
|||
import { useMutation, useQuery, useQueryClient } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigParams, IMAGE_FILTER, STALE_TIME } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../state.js';
|
||||
import { IMAGE_FILTER, STALE_TIME } from '../config.js';
|
||||
import { ClientContext, ConfigContext, StateContext } from '../state.js';
|
||||
import { MASK_LABELS, NOISE_LABELS } from '../strings.js';
|
||||
import { ImageControl } from './ImageControl.js';
|
||||
import { ImageInput } from './ImageInput.js';
|
||||
|
@ -16,16 +16,10 @@ import { UpscaleControl } from './UpscaleControl.js';
|
|||
|
||||
const { useContext } = React;
|
||||
|
||||
export interface InpaintProps {
|
||||
config: ConfigParams;
|
||||
|
||||
model: string;
|
||||
platform: string;
|
||||
}
|
||||
|
||||
export function Inpaint(props: InpaintProps) {
|
||||
const { config, model, platform } = props;
|
||||
export function Inpaint() {
|
||||
const config = mustExist(useContext(ConfigContext));
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
|
||||
const masks = useQuery('masks', async () => client.masks(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
@ -35,24 +29,20 @@ export function Inpaint(props: InpaintProps) {
|
|||
|
||||
async function uploadSource(): Promise<void> {
|
||||
// these are not watched by the component, only sent by the mutation
|
||||
const { inpaint, outpaint, upscale } = state.getState();
|
||||
const { model, inpaint, outpaint, upscale } = state.getState();
|
||||
|
||||
if (outpaint.enabled) {
|
||||
const output = await client.outpaint({
|
||||
const output = await client.outpaint(model, {
|
||||
...inpaint,
|
||||
...outpaint,
|
||||
model,
|
||||
platform,
|
||||
mask: mustExist(mask),
|
||||
source: mustExist(source),
|
||||
}, upscale);
|
||||
|
||||
setLoading(output);
|
||||
} else {
|
||||
const output = await client.inpaint({
|
||||
const output = await client.inpaint(model, {
|
||||
...inpaint,
|
||||
model,
|
||||
platform,
|
||||
mask: mustExist(mask),
|
||||
source: mustExist(source),
|
||||
}, upscale);
|
||||
|
@ -122,7 +112,9 @@ export function Inpaint(props: InpaintProps) {
|
|||
id='masks'
|
||||
labels={MASK_LABELS}
|
||||
name='Mask Filter'
|
||||
result={masks}
|
||||
query={{
|
||||
result: masks,
|
||||
}}
|
||||
value={filter}
|
||||
onChange={(newFilter) => {
|
||||
setInpaint({
|
||||
|
@ -134,7 +126,9 @@ export function Inpaint(props: InpaintProps) {
|
|||
id='noises'
|
||||
labels={NOISE_LABELS}
|
||||
name='Noise Source'
|
||||
result={noises}
|
||||
query={{
|
||||
result: noises,
|
||||
}}
|
||||
value={noise}
|
||||
onChange={(newNoise) => {
|
||||
setInpaint({
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
import { useQuery } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { STALE_TIME } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../state.js';
|
||||
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
|
||||
import { QueryList } from './QueryList.js';
|
||||
|
||||
export function ModelControl() {
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const params = useStore(state, (s) => s.model);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setModel = useStore(state, (s) => s.setModel);
|
||||
|
||||
const models = useQuery('models', async () => client.models(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
const platforms = useQuery('platforms', async () => client.platforms(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
return <Stack direction='row' spacing={2}>
|
||||
<QueryList
|
||||
id='platforms'
|
||||
labels={PLATFORM_LABELS}
|
||||
name='Platform'
|
||||
query={{
|
||||
result: platforms,
|
||||
}}
|
||||
value={params.platform}
|
||||
onChange={(platform) => {
|
||||
setModel({
|
||||
platform,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<QueryList
|
||||
id='diffusion'
|
||||
labels={MODEL_LABELS}
|
||||
name='Diffusion Model'
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.diffusion,
|
||||
}}
|
||||
value={params.model}
|
||||
onChange={(model) => {
|
||||
setModel({
|
||||
model,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<QueryList
|
||||
id='upscaling'
|
||||
labels={MODEL_LABELS}
|
||||
name='Upscaling Model'
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.upscaling,
|
||||
}}
|
||||
value={params.model}
|
||||
onChange={(model) => {
|
||||
setModel({
|
||||
model,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<QueryList
|
||||
id='correction'
|
||||
labels={MODEL_LABELS}
|
||||
name='Correction Model'
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.correction,
|
||||
}}
|
||||
value={params.model}
|
||||
onChange={(model) => {
|
||||
setModel({
|
||||
model,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
|
||||
</Stack>;
|
||||
}
|
|
@ -1,41 +1,18 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { TabContext, TabList, TabPanel } from '@mui/lab';
|
||||
import { Box, Container, Divider, Link, Stack, Tab, Typography } from '@mui/material';
|
||||
import { Box, Container, Divider, Link, Tab, Typography } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useQuery } from 'react-query';
|
||||
|
||||
import { ApiClient } from '../client.js';
|
||||
import { ConfigParams, STALE_TIME } from '../config.js';
|
||||
import { ClientContext } from '../state.js';
|
||||
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
|
||||
import { ImageHistory } from './ImageHistory.js';
|
||||
import { Img2Img } from './Img2Img.js';
|
||||
import { Inpaint } from './Inpaint.js';
|
||||
import { QueryList } from './QueryList.js';
|
||||
import { ModelControl } from './ModelControl.js';
|
||||
import { Settings } from './Settings.js';
|
||||
import { Txt2Img } from './Txt2Img.js';
|
||||
|
||||
const { useContext, useState } = React;
|
||||
const { useState } = React;
|
||||
|
||||
export interface OnnxWebProps {
|
||||
client: ApiClient;
|
||||
config: ConfigParams;
|
||||
}
|
||||
|
||||
export function OnnxWeb(props: OnnxWebProps) {
|
||||
const { config } = props;
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
export function OnnxWeb() {
|
||||
const [tab, setTab] = useState('txt2img');
|
||||
const [model, setModel] = useState(config.model.default);
|
||||
const [platform, setPlatform] = useState(config.platform.default);
|
||||
|
||||
const models = useQuery('models', async () => client.models(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
const platforms = useQuery('platforms', async () => client.platforms(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
return (
|
||||
<Container>
|
||||
|
@ -45,28 +22,7 @@ export function OnnxWeb(props: OnnxWebProps) {
|
|||
</Typography>
|
||||
</Box>
|
||||
<Box sx={{ mx: 4, my: 4 }}>
|
||||
<Stack direction='row' spacing={2}>
|
||||
<QueryList
|
||||
id='models'
|
||||
labels={MODEL_LABELS}
|
||||
name='Model'
|
||||
result={models}
|
||||
value={model}
|
||||
onChange={(value) => {
|
||||
setModel(value);
|
||||
}}
|
||||
/>
|
||||
<QueryList
|
||||
id='platforms'
|
||||
labels={PLATFORM_LABELS}
|
||||
name='Platform'
|
||||
result={platforms}
|
||||
value={platform}
|
||||
onChange={(value) => {
|
||||
setPlatform(value);
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
<ModelControl />
|
||||
</Box>
|
||||
<TabContext value={tab}>
|
||||
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
|
||||
|
@ -80,16 +36,16 @@ export function OnnxWeb(props: OnnxWebProps) {
|
|||
</TabList>
|
||||
</Box>
|
||||
<TabPanel value='txt2img'>
|
||||
<Txt2Img config={config} model={model} platform={platform} />
|
||||
<Txt2Img />
|
||||
</TabPanel>
|
||||
<TabPanel value='img2img'>
|
||||
<Img2Img config={config} model={model} platform={platform} />
|
||||
<Img2Img />
|
||||
</TabPanel>
|
||||
<TabPanel value='inpaint'>
|
||||
<Inpaint config={config} model={model} platform={platform} />
|
||||
<Inpaint />
|
||||
</TabPanel>
|
||||
<TabPanel value='settings'>
|
||||
<Settings config={config} />
|
||||
<Settings />
|
||||
</TabPanel>
|
||||
</TabContext>
|
||||
<Divider variant='middle' />
|
||||
|
|
|
@ -3,18 +3,42 @@ import { FormControl, InputLabel, MenuItem, Select } from '@mui/material';
|
|||
import * as React from 'react';
|
||||
import { UseQueryResult } from 'react-query';
|
||||
|
||||
export interface QueryListProps {
|
||||
export interface QueryListComplete {
|
||||
result: UseQueryResult<Array<string>>;
|
||||
}
|
||||
|
||||
export interface QueryListFilter<T> {
|
||||
result: UseQueryResult<T>;
|
||||
selector: (result: T) => Array<string>;
|
||||
}
|
||||
|
||||
export interface QueryListProps<T> {
|
||||
id: string;
|
||||
labels: Record<string, string>;
|
||||
name: string;
|
||||
result: UseQueryResult<Array<string>>;
|
||||
value: string;
|
||||
|
||||
query: QueryListComplete | QueryListFilter<T>;
|
||||
|
||||
onChange?: (value: string) => void;
|
||||
}
|
||||
|
||||
export function QueryList(props: QueryListProps) {
|
||||
const { labels, result, value } = props;
|
||||
export function hasFilter<T>(query: QueryListComplete | QueryListFilter<T>): query is QueryListFilter<T> {
|
||||
return Reflect.has(query, 'selector');
|
||||
}
|
||||
|
||||
export function filterQuery<T>(query: QueryListComplete | QueryListFilter<T>): Array<string> {
|
||||
if (hasFilter(query)) {
|
||||
const data = mustExist(query.result.data);
|
||||
return (query as QueryListFilter<unknown>).selector(data);
|
||||
} else {
|
||||
return mustExist(query.result.data);
|
||||
}
|
||||
}
|
||||
|
||||
export function QueryList<T>(props: QueryListProps<T>) {
|
||||
const { labels, query, value } = props;
|
||||
const { result } = query;
|
||||
|
||||
if (result.status === 'error') {
|
||||
if (result.error instanceof Error) {
|
||||
|
@ -34,7 +58,8 @@ export function QueryList(props: QueryListProps) {
|
|||
|
||||
// else: success
|
||||
const labelID = `query-list-${props.id}-labels`;
|
||||
const data = mustExist(result.data);
|
||||
const data = filterQuery(query);
|
||||
|
||||
return <FormControl>
|
||||
<InputLabel id={labelID}>{props.name}</InputLabel>
|
||||
<Select
|
||||
|
|
|
@ -1,19 +1,13 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Button, Stack, TextField } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigParams } from '../config.js';
|
||||
import { StateContext } from '../state.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
|
||||
const { useContext } = React;
|
||||
|
||||
export interface SettingsProps {
|
||||
config: ConfigParams;
|
||||
}
|
||||
|
||||
export function Settings(_props: SettingsProps) {
|
||||
export function Settings() {
|
||||
const state = useStore(mustExist(useContext(StateContext)));
|
||||
|
||||
return <Stack spacing={2}>
|
||||
|
@ -25,16 +19,6 @@ export function Settings(_props: SettingsProps) {
|
|||
value={state.limit}
|
||||
onChange={(value) => state.setLimit(value)}
|
||||
/>
|
||||
<TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => {
|
||||
state.setDefaults({
|
||||
model: event.target.value,
|
||||
});
|
||||
}} />
|
||||
<TextField variant='outlined' label='Default Platform' value={state.defaults.platform} onChange={(event) => {
|
||||
state.setDefaults({
|
||||
platform: event.target.value,
|
||||
});
|
||||
}} />
|
||||
<TextField variant='outlined' label='Default Prompt' value={state.defaults.prompt} onChange={(event) => {
|
||||
state.setDefaults({
|
||||
prompt: event.target.value,
|
||||
|
|
|
@ -4,31 +4,19 @@ import * as React from 'react';
|
|||
import { useMutation, useQueryClient } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigParams } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../state.js';
|
||||
import { ClientContext, ConfigContext, StateContext } from '../state.js';
|
||||
import { ImageControl } from './ImageControl.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
import { UpscaleControl } from './UpscaleControl.js';
|
||||
|
||||
const { useContext } = React;
|
||||
|
||||
export interface Txt2ImgProps {
|
||||
config: ConfigParams;
|
||||
|
||||
model: string;
|
||||
platform: string;
|
||||
}
|
||||
|
||||
export function Txt2Img(props: Txt2ImgProps) {
|
||||
const { config, model, platform } = props;
|
||||
export function Txt2Img() {
|
||||
const config = mustExist(useContext(ConfigContext));
|
||||
|
||||
async function generateImage() {
|
||||
const { txt2img, upscale } = state.getState();
|
||||
const output = await client.txt2img({
|
||||
...txt2img,
|
||||
model,
|
||||
platform,
|
||||
}, upscale);
|
||||
const { model, txt2img, upscale } = state.getState();
|
||||
const output = await client.txt2img(model, txt2img, upscale);
|
||||
|
||||
setLoading(output);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import { Maybe } from '@apextoaster/js-utils';
|
||||
|
||||
import { Img2ImgParams, InpaintParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client.js';
|
||||
import { Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './client.js';
|
||||
|
||||
export interface ConfigNumber {
|
||||
default: number;
|
||||
|
@ -30,7 +30,16 @@ export type ConfigState<T extends object, TValid = number | string> = {
|
|||
[K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never;
|
||||
};
|
||||
|
||||
export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams & InpaintParams & OutpaintParams & UpscaleParams>>;
|
||||
/* eslint-disable */
|
||||
export type ConfigParams = ConfigRanges<Required<
|
||||
Img2ImgParams &
|
||||
Txt2ImgParams &
|
||||
InpaintParams &
|
||||
ModelParams &
|
||||
OutpaintParams &
|
||||
UpscaleParams
|
||||
>>;
|
||||
/* eslint-enable */
|
||||
|
||||
export interface Config {
|
||||
api: {
|
||||
|
|
|
@ -11,7 +11,7 @@ import { makeClient } from './client.js';
|
|||
import { OnnxError } from './components/OnnxError.js';
|
||||
import { OnnxWeb } from './components/OnnxWeb.js';
|
||||
import { Config, loadConfig } from './config.js';
|
||||
import { ClientContext, createStateSlices, OnnxState, StateContext } from './state.js';
|
||||
import { ClientContext, ConfigContext, createStateSlices, OnnxState, StateContext } from './state.js';
|
||||
|
||||
export function getApiRoot(config: Config): string {
|
||||
const query = new URLSearchParams(window.location.search);
|
||||
|
@ -48,6 +48,7 @@ export async function main() {
|
|||
createHistorySlice,
|
||||
createImg2ImgSlice,
|
||||
createInpaintSlice,
|
||||
createModelSlice,
|
||||
createOutpaintSlice,
|
||||
createTxt2ImgSlice,
|
||||
createUpscaleSlice,
|
||||
|
@ -58,6 +59,7 @@ export async function main() {
|
|||
...createHistorySlice(...slice),
|
||||
...createImg2ImgSlice(...slice),
|
||||
...createInpaintSlice(...slice),
|
||||
...createModelSlice(...slice),
|
||||
...createTxt2ImgSlice(...slice),
|
||||
...createOutpaintSlice(...slice),
|
||||
...createUpscaleSlice(...slice),
|
||||
|
@ -87,9 +89,11 @@ export async function main() {
|
|||
// go
|
||||
app.render(<QueryClientProvider client={query}>
|
||||
<ClientContext.Provider value={client}>
|
||||
<StateContext.Provider value={state}>
|
||||
<OnnxWeb client={client} config={params} />
|
||||
</StateContext.Provider>
|
||||
<ConfigContext.Provider value={params}>
|
||||
<StateContext.Provider value={state}>
|
||||
<OnnxWeb />
|
||||
</StateContext.Provider>
|
||||
</ConfigContext.Provider>
|
||||
</ClientContext.Provider>
|
||||
</QueryClientProvider>);
|
||||
} catch (err) {
|
||||
|
|
|
@ -10,6 +10,7 @@ import {
|
|||
BrushParams,
|
||||
Img2ImgParams,
|
||||
InpaintParams,
|
||||
ModelParams,
|
||||
OutpaintPixels,
|
||||
paramsFromConfig,
|
||||
Txt2ImgParams,
|
||||
|
@ -75,12 +76,19 @@ interface UpscaleSlice {
|
|||
setUpscale(upscale: Partial<UpscaleParams>): void;
|
||||
}
|
||||
|
||||
interface ModelSlice {
|
||||
model: ModelParams;
|
||||
|
||||
setModel(model: Partial<ModelParams>): void;
|
||||
}
|
||||
|
||||
export type OnnxState
|
||||
= BrushSlice
|
||||
& DefaultSlice
|
||||
& HistorySlice
|
||||
& Img2ImgSlice
|
||||
& InpaintSlice
|
||||
& ModelSlice
|
||||
& OutpaintSlice
|
||||
& Txt2ImgSlice
|
||||
& UpscaleSlice;
|
||||
|
@ -267,12 +275,30 @@ export function createStateSlices(base: ConfigParams) {
|
|||
},
|
||||
});
|
||||
|
||||
const createModelSlice: StateCreator<OnnxState, [], [], ModelSlice> = (set) => ({
|
||||
model: {
|
||||
model: '',
|
||||
platform: '',
|
||||
upscaling: '',
|
||||
correction: '',
|
||||
},
|
||||
setModel(params) {
|
||||
set((prev) => ({
|
||||
model: {
|
||||
...prev.model,
|
||||
...params,
|
||||
}
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
createBrushSlice,
|
||||
createDefaultSlice,
|
||||
createHistorySlice,
|
||||
createImg2ImgSlice,
|
||||
createInpaintSlice,
|
||||
createModelSlice,
|
||||
createOutpaintSlice,
|
||||
createTxt2ImgSlice,
|
||||
createUpscaleSlice,
|
||||
|
@ -280,4 +306,5 @@ export function createStateSlices(base: ConfigParams) {
|
|||
}
|
||||
|
||||
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
|
||||
export const ConfigContext = createContext<Maybe<ConfigParams>>(undefined);
|
||||
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
"ddpm",
|
||||
"denoise",
|
||||
"directml",
|
||||
"ESRGAN",
|
||||
"ftfy",
|
||||
"gfpgan",
|
||||
"Heun",
|
||||
|
@ -33,20 +34,24 @@
|
|||
"numpy",
|
||||
"Onnx",
|
||||
"onnxruntime",
|
||||
"opset",
|
||||
"outpaint",
|
||||
"outscale",
|
||||
"pndm",
|
||||
"pretrained",
|
||||
"protobuf",
|
||||
"resrgan",
|
||||
"RRDB",
|
||||
"runwayml",
|
||||
"scandir",
|
||||
"scipy",
|
||||
"Singlestep",
|
||||
"spacy",
|
||||
"spinalcase",
|
||||
"stabilityai",
|
||||
"stringcase",
|
||||
"upsampler",
|
||||
"upscaling",
|
||||
"venv",
|
||||
"virtualenv",
|
||||
"zustand"
|
||||
|
|
Loading…
Reference in New Issue