From 0080d86d91fbb476684a0e8614a3950c25cede06 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 16 Jan 2023 20:11:10 -0600 Subject: [PATCH] feat(gui): add menus for upscaling and correction models --- gui/src/client.ts | 66 ++++++++++++--------- gui/src/components/ImageControl.tsx | 4 +- gui/src/components/Img2Img.tsx | 21 ++----- gui/src/components/Inpaint.tsx | 34 +++++------ gui/src/components/ModelControl.tsx | 89 +++++++++++++++++++++++++++++ gui/src/components/OnnxWeb.tsx | 62 +++----------------- gui/src/components/QueryList.tsx | 35 ++++++++++-- gui/src/components/Settings.tsx | 20 +------ gui/src/components/Txt2Img.tsx | 22 ++----- gui/src/config.ts | 13 ++++- gui/src/main.tsx | 12 ++-- gui/src/state.ts | 27 +++++++++ onnx-web.code-workspace | 5 ++ 13 files changed, 247 insertions(+), 163 deletions(-) create mode 100644 gui/src/components/ModelControl.tsx diff --git a/gui/src/client.ts b/gui/src/client.ts index d298a8f9..0db8e653 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -2,22 +2,23 @@ import { doesExist } from '@apextoaster/js-utils'; import { ConfigParams } from './config.js'; -export interface BaseImgParams { +export interface ModelParams { /** * Which ONNX model to use. */ - model?: string; + model: string; /** * Hardware accelerator or CPU mode. */ - platform?: string; + platform: string; - /** - * Scheduling algorithm. - */ - scheduler?: string; + upscaling: string; + correction: string; +} +export interface BaseImgParams { + scheduler: string; prompt: string; negativePrompt?: string; @@ -90,18 +91,24 @@ export interface ApiReady { ready: boolean; } +export interface ApiModels { + diffusion: Array; + correction: Array; + upscaling: Array; +} + export interface ApiClient { masks(): Promise>; - models(): Promise>; + models(): Promise; noises(): Promise>; params(): Promise; platforms(): Promise>; schedulers(): Promise>; - img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise; - txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise; - inpaint(params: InpaintParams, upscale?: UpscaleParams): Promise; - outpaint(params: OutpaintParams, upscale?: UpscaleParams): Promise; + img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise; + txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise; + inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise; + outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise; ready(params: ApiResponse): Promise; } @@ -111,9 +118,7 @@ export const STATUS_SUCCESS = 200; export function paramsFromConfig(defaults: ConfigParams): Required { return { cfg: defaults.cfg.default, - model: defaults.model.default, negativePrompt: defaults.negativePrompt.default, - platform: defaults.platform.default, prompt: defaults.prompt.default, scheduler: defaults.scheduler.default, steps: defaults.steps.default, @@ -141,14 +146,6 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams): url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT)); url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER)); - if (doesExist(params.model)) { - url.searchParams.append('model', params.model); - } - - if (doesExist(params.platform)) { - url.searchParams.append('platform', params.platform); - } - if (doesExist(params.scheduler)) { url.searchParams.append('scheduler', params.scheduler); } @@ -167,6 +164,11 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams): return url; } +export function appendModelToURL(url: URL, params: ModelParams) { + url.searchParams.append('model', params.model); + url.searchParams.append('platform', params.platform); +} + export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { if (upscale.enabled) { url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT)); @@ -191,10 +193,10 @@ export function makeClient(root: string, f = fetch): ApiClient { const res = await f(path); return await res.json() as Array; }, - async models(): Promise> { + async models(): Promise { const path = makeApiUrl(root, 'settings', 'models'); const res = await f(path); - return await res.json() as Array; + return await res.json() as ApiModels; }, async noises(): Promise> { const path = makeApiUrl(root, 'settings', 'noises'); @@ -216,12 +218,14 @@ export function makeClient(root: string, f = fetch): ApiClient { const res = await f(path); return await res.json() as Array; }, - async img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise { + async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise { if (doesExist(pending)) { return pending; } const url = makeImageURL(root, 'img2img', params); + appendModelToURL(url, model); + url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); if (doesExist(upscale)) { @@ -239,12 +243,13 @@ export function makeClient(root: string, f = fetch): ApiClient { // eslint-disable-next-line no-return-await return await pending; }, - async txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise { + async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise { if (doesExist(pending)) { return pending; } const url = makeImageURL(root, 'txt2img', params); + appendModelToURL(url, model); if (doesExist(params.width)) { url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER)); @@ -265,14 +270,17 @@ export function makeClient(root: string, f = fetch): ApiClient { // eslint-disable-next-line no-return-await return await pending; }, - async inpaint(params: InpaintParams, upscale?: UpscaleParams) { + async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams) { if (doesExist(pending)) { return pending; } const url = makeImageURL(root, 'inpaint', params); + appendModelToURL(url, model); + url.searchParams.append('filter', params.filter); url.searchParams.append('noise', params.noise); + if (doesExist(upscale)) { appendUpscaleToURL(url, upscale); } @@ -289,12 +297,14 @@ export function makeClient(root: string, f = fetch): ApiClient { // eslint-disable-next-line no-return-await return await pending; }, - async outpaint(params: OutpaintParams, upscale?: UpscaleParams) { + async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams) { if (doesExist(pending)) { return pending; } const url = makeImageURL(root, 'inpaint', params); + appendModelToURL(url, model); + url.searchParams.append('filter', params.filter); url.searchParams.append('noise', params.noise); diff --git a/gui/src/components/ImageControl.tsx b/gui/src/components/ImageControl.tsx index 4506518b..08625f14 100644 --- a/gui/src/components/ImageControl.tsx +++ b/gui/src/components/ImageControl.tsx @@ -41,7 +41,9 @@ export function ImageControl(props: ImageControlProps) { id='schedulers' labels={SCHEDULER_LABELS} name='Scheduler' - result={schedulers} + query={{ + result: schedulers, + }} value={mustDefault(params.scheduler, '')} onChange={(value) => { if (doesExist(props.onChange)) { diff --git a/gui/src/components/Img2Img.tsx b/gui/src/components/Img2Img.tsx index aa67f5e4..b8a8c309 100644 --- a/gui/src/components/Img2Img.tsx +++ b/gui/src/components/Img2Img.tsx @@ -4,8 +4,8 @@ import * as React from 'react'; import { useMutation, useQueryClient } from 'react-query'; import { useStore } from 'zustand'; -import { ConfigParams, IMAGE_FILTER } from '../config.js'; -import { ClientContext, StateContext } from '../state.js'; +import { IMAGE_FILTER } from '../config.js'; +import { ClientContext, ConfigContext, StateContext } from '../state.js'; import { ImageControl } from './ImageControl.js'; import { ImageInput } from './ImageInput.js'; import { NumericField } from './NumericField.js'; @@ -13,23 +13,14 @@ import { UpscaleControl } from './UpscaleControl.js'; const { useContext } = React; -export interface Img2ImgProps { - config: ConfigParams; - - model: string; - platform: string; -} - -export function Img2Img(props: Img2ImgProps) { - const { config, model, platform } = props; +export function Img2Img() { + const config = mustExist(useContext(ConfigContext)); async function uploadSource() { - const { img2img, upscale } = state.getState(); + const { model, img2img, upscale } = state.getState(); - const output = await client.img2img({ + const output = await client.img2img(model, { ...img2img, - model, - platform, source: mustExist(img2img.source), // TODO: show an error if this doesn't exist }, upscale); diff --git a/gui/src/components/Inpaint.tsx b/gui/src/components/Inpaint.tsx index 3c0966fc..3cc20d26 100644 --- a/gui/src/components/Inpaint.tsx +++ b/gui/src/components/Inpaint.tsx @@ -4,8 +4,8 @@ import * as React from 'react'; import { useMutation, useQuery, useQueryClient } from 'react-query'; import { useStore } from 'zustand'; -import { ConfigParams, IMAGE_FILTER, STALE_TIME } from '../config.js'; -import { ClientContext, StateContext } from '../state.js'; +import { IMAGE_FILTER, STALE_TIME } from '../config.js'; +import { ClientContext, ConfigContext, StateContext } from '../state.js'; import { MASK_LABELS, NOISE_LABELS } from '../strings.js'; import { ImageControl } from './ImageControl.js'; import { ImageInput } from './ImageInput.js'; @@ -16,16 +16,10 @@ import { UpscaleControl } from './UpscaleControl.js'; const { useContext } = React; -export interface InpaintProps { - config: ConfigParams; - - model: string; - platform: string; -} - -export function Inpaint(props: InpaintProps) { - const { config, model, platform } = props; +export function Inpaint() { + const config = mustExist(useContext(ConfigContext)); const client = mustExist(useContext(ClientContext)); + const masks = useQuery('masks', async () => client.masks(), { staleTime: STALE_TIME, }); @@ -35,24 +29,20 @@ export function Inpaint(props: InpaintProps) { async function uploadSource(): Promise { // these are not watched by the component, only sent by the mutation - const { inpaint, outpaint, upscale } = state.getState(); + const { model, inpaint, outpaint, upscale } = state.getState(); if (outpaint.enabled) { - const output = await client.outpaint({ + const output = await client.outpaint(model, { ...inpaint, ...outpaint, - model, - platform, mask: mustExist(mask), source: mustExist(source), }, upscale); setLoading(output); } else { - const output = await client.inpaint({ + const output = await client.inpaint(model, { ...inpaint, - model, - platform, mask: mustExist(mask), source: mustExist(source), }, upscale); @@ -122,7 +112,9 @@ export function Inpaint(props: InpaintProps) { id='masks' labels={MASK_LABELS} name='Mask Filter' - result={masks} + query={{ + result: masks, + }} value={filter} onChange={(newFilter) => { setInpaint({ @@ -134,7 +126,9 @@ export function Inpaint(props: InpaintProps) { id='noises' labels={NOISE_LABELS} name='Noise Source' - result={noises} + query={{ + result: noises, + }} value={noise} onChange={(newNoise) => { setInpaint({ diff --git a/gui/src/components/ModelControl.tsx b/gui/src/components/ModelControl.tsx new file mode 100644 index 00000000..ef1b2f4b --- /dev/null +++ b/gui/src/components/ModelControl.tsx @@ -0,0 +1,89 @@ +import { mustExist } from '@apextoaster/js-utils'; +import { Stack } from '@mui/material'; +import * as React from 'react'; +import { useContext } from 'react'; +import { useQuery } from 'react-query'; +import { useStore } from 'zustand'; + +import { STALE_TIME } from '../config.js'; +import { ClientContext, StateContext } from '../state.js'; +import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js'; +import { QueryList } from './QueryList.js'; + +export function ModelControl() { + const client = mustExist(useContext(ClientContext)); + const state = mustExist(useContext(StateContext)); + const params = useStore(state, (s) => s.model); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setModel = useStore(state, (s) => s.setModel); + + const models = useQuery('models', async () => client.models(), { + staleTime: STALE_TIME, + }); + const platforms = useQuery('platforms', async () => client.platforms(), { + staleTime: STALE_TIME, + }); + + return + { + setModel({ + platform, + }); + }} + /> + result.diffusion, + }} + value={params.model} + onChange={(model) => { + setModel({ + model, + }); + }} + /> + result.upscaling, + }} + value={params.model} + onChange={(model) => { + setModel({ + model, + }); + }} + /> + result.correction, + }} + value={params.model} + onChange={(model) => { + setModel({ + model, + }); + }} + /> + + ; +} diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx index ffeced25..b71ccbb4 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -1,41 +1,18 @@ -import { mustExist } from '@apextoaster/js-utils'; import { TabContext, TabList, TabPanel } from '@mui/lab'; -import { Box, Container, Divider, Link, Stack, Tab, Typography } from '@mui/material'; +import { Box, Container, Divider, Link, Tab, Typography } from '@mui/material'; import * as React from 'react'; -import { useQuery } from 'react-query'; -import { ApiClient } from '../client.js'; -import { ConfigParams, STALE_TIME } from '../config.js'; -import { ClientContext } from '../state.js'; -import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js'; import { ImageHistory } from './ImageHistory.js'; import { Img2Img } from './Img2Img.js'; import { Inpaint } from './Inpaint.js'; -import { QueryList } from './QueryList.js'; +import { ModelControl } from './ModelControl.js'; import { Settings } from './Settings.js'; import { Txt2Img } from './Txt2Img.js'; -const { useContext, useState } = React; +const { useState } = React; -export interface OnnxWebProps { - client: ApiClient; - config: ConfigParams; -} - -export function OnnxWeb(props: OnnxWebProps) { - const { config } = props; - - const client = mustExist(useContext(ClientContext)); +export function OnnxWeb() { const [tab, setTab] = useState('txt2img'); - const [model, setModel] = useState(config.model.default); - const [platform, setPlatform] = useState(config.platform.default); - - const models = useQuery('models', async () => client.models(), { - staleTime: STALE_TIME, - }); - const platforms = useQuery('platforms', async () => client.platforms(), { - staleTime: STALE_TIME, - }); return ( @@ -45,28 +22,7 @@ export function OnnxWeb(props: OnnxWebProps) { - - { - setModel(value); - }} - /> - { - setPlatform(value); - }} - /> - + @@ -80,16 +36,16 @@ export function OnnxWeb(props: OnnxWebProps) { - + - + - + - + diff --git a/gui/src/components/QueryList.tsx b/gui/src/components/QueryList.tsx index 378a13c4..72eb6dd3 100644 --- a/gui/src/components/QueryList.tsx +++ b/gui/src/components/QueryList.tsx @@ -3,18 +3,42 @@ import { FormControl, InputLabel, MenuItem, Select } from '@mui/material'; import * as React from 'react'; import { UseQueryResult } from 'react-query'; -export interface QueryListProps { +export interface QueryListComplete { + result: UseQueryResult>; +} + +export interface QueryListFilter { + result: UseQueryResult; + selector: (result: T) => Array; +} + +export interface QueryListProps { id: string; labels: Record; name: string; - result: UseQueryResult>; value: string; + query: QueryListComplete | QueryListFilter; + onChange?: (value: string) => void; } -export function QueryList(props: QueryListProps) { - const { labels, result, value } = props; +export function hasFilter(query: QueryListComplete | QueryListFilter): query is QueryListFilter { + return Reflect.has(query, 'selector'); +} + +export function filterQuery(query: QueryListComplete | QueryListFilter): Array { + if (hasFilter(query)) { + const data = mustExist(query.result.data); + return (query as QueryListFilter).selector(data); + } else { + return mustExist(query.result.data); + } +} + +export function QueryList(props: QueryListProps) { + const { labels, query, value } = props; + const { result } = query; if (result.status === 'error') { if (result.error instanceof Error) { @@ -34,7 +58,8 @@ export function QueryList(props: QueryListProps) { // else: success const labelID = `query-list-${props.id}-labels`; - const data = mustExist(result.data); + const data = filterQuery(query); + return {props.name}