From f14f197264ba2b1f2bc9b396eb02490a341e7d9b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 21 Jul 2023 17:38:01 -0500 Subject: [PATCH] feat(gui): move model controls into each tab --- gui/src/client/types.ts | 4 +- gui/src/components/OnnxWeb.tsx | 4 - gui/src/components/Profiles.tsx | 84 +-- gui/src/components/card/ImageCard.tsx | 2 +- gui/src/components/control/HighresControl.tsx | 20 +- gui/src/components/control/ImageControl.tsx | 113 ++-- gui/src/components/control/ModelControl.tsx | 224 +++----- gui/src/components/control/UpscaleControl.tsx | 20 +- gui/src/components/input/MaskCanvas.tsx | 12 +- gui/src/components/input/PromptInput.tsx | 49 +- gui/src/components/tab/Blend.tsx | 16 +- gui/src/components/tab/Img2Img.tsx | 31 +- gui/src/components/tab/Inpaint.tsx | 51 +- gui/src/components/tab/Txt2Img.tsx | 33 +- gui/src/components/tab/Upscale.tsx | 39 +- gui/src/main.tsx | 12 +- gui/src/state.ts | 511 +++++++++++------- 17 files changed, 686 insertions(+), 539 deletions(-) diff --git a/gui/src/client/types.ts b/gui/src/client/types.ts index 8de792b9..ee5907e6 100644 --- a/gui/src/client/types.ts +++ b/gui/src/client/types.ts @@ -140,9 +140,7 @@ export interface UpscaleParams { /** * Parameters for upscale requests. */ -export interface UpscaleReqParams { - prompt: string; - negativePrompt?: string; +export interface UpscaleReqParams extends BaseImgParams { source: Blob; } diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx index ca69a816..dcc8634f 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -8,7 +8,6 @@ import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; import { StateContext } from '../state.js'; -import { ModelControl } from './control/ModelControl.js'; import { ImageHistory } from './ImageHistory.js'; import { Logo } from './Logo.js'; import { Blend } from './tab/Blend.js'; @@ -43,9 +42,6 @@ export function OnnxWeb() { - - - { diff --git a/gui/src/components/Profiles.tsx b/gui/src/components/Profiles.tsx index 65ffb5bc..e3a9230d 100644 --- a/gui/src/components/Profiles.tsx +++ b/gui/src/components/Profiles.tsx @@ -20,29 +20,35 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { BaseImgParams, Txt2ImgParams } from '../client/types.js'; +import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../client/types.js'; import { StateContext } from '../state.js'; +const { useState, Fragment } = React; + export interface ProfilesProps { + highres: HighresParams; params: BaseImgParams; - setParams: ((params: BaseImgParams) => void) | undefined; + upscale: UpscaleParams; + + setHighres(params: HighresParams): void; + setParams(params: BaseImgParams): void; + setUpscale(params: UpscaleParams): void; } export function Profiles(props: ProfilesProps) { const state = mustExist(useContext(StateContext)); + const profiles = useStore(state, (s) => s.profiles); // eslint-disable-next-line @typescript-eslint/unbound-method const saveProfile = useStore(state, (s) => s.saveProfile); // eslint-disable-next-line @typescript-eslint/unbound-method const removeProfile = useStore(state, (s) => s.removeProfile); - const profiles = useStore(state, (s) => s.profiles); - const highres = useStore(state, (s) => s.highres); - const upscale = useStore(state, (s) => s.upscale); - const [dialogOpen, setDialogOpen] = React.useState(false); - const [profileName, setProfileName] = React.useState(''); + + const [dialogOpen, setDialogOpen] = useState(false); + const [profileName, setProfileName] = useState(''); const { t } = useTranslation(); - return <> + return setDialogOpen(true)}> - )} onChange={(event, value) => { - if (doesExist(value) && doesExist(props.setParams)) { + if (doesExist(value)) { props.setParams({ ...value.params }); @@ -138,8 +118,8 @@ export function Profiles(props: ProfilesProps) { saveProfile({ params: props.params, name: profileName, - highResParams: highres, - upscaleParams: upscale, + highResParams: props.highres, + upscaleParams: props.upscale, }); setDialogOpen(false); setProfileName(''); @@ -147,7 +127,33 @@ export function Profiles(props: ProfilesProps) { >{t('profile.save')} - ; + + ; } export async function loadParamsFromFile(file: File): Promise> { @@ -276,7 +282,7 @@ export async function parseAutoComment(comment: string): Promise s.setInpaint); // eslint-disable-next-line @typescript-eslint/unbound-method - const setUpscale = useStore(state, (s) => s.setUpscaleTab); + const setUpscale = useStore(state, (s) => s.setUpscale); // eslint-disable-next-line @typescript-eslint/unbound-method const setBlend = useStore(state, (s) => s.setBlend); diff --git a/gui/src/components/control/HighresControl.tsx b/gui/src/components/control/HighresControl.tsx index 3467a947..4d2380c2 100644 --- a/gui/src/components/control/HighresControl.tsx +++ b/gui/src/components/control/HighresControl.tsx @@ -3,17 +3,21 @@ import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, import * as React from 'react'; import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; -import { useStore } from 'zustand'; -import { ConfigContext, StateContext } from '../../state.js'; +import { HighresParams } from '../../client/types.js'; +import { ConfigContext } from '../../state.js'; import { NumericField } from '../input/NumericField.js'; -export function HighresControl() { - const { params } = mustExist(useContext(ConfigContext)); - const state = mustExist(useContext(StateContext)); - const highres = useStore(state, (s) => s.highres); +export interface HighresControlProps { + highres: HighresParams; + setHighres(params: Partial): void; +} + +export function HighresControl(props: HighresControlProps) { // eslint-disable-next-line @typescript-eslint/unbound-method - const setHighres = useStore(state, (s) => s.setHighres); + const { highres, setHighres } = props; + + const { params } = mustExist(useContext(ConfigContext)); const { t } = useTranslation(); return @@ -22,7 +26,7 @@ export function HighresControl() { control={ { + onChange={(_event) => { setHighres({ enabled: highres.enabled === false, }); diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index c6786d39..cf40979d 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -13,21 +13,21 @@ import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../sta import { NumericField } from '../input/NumericField.js'; import { PromptInput } from '../input/PromptInput.js'; import { QueryList } from '../input/QueryList.js'; -import { Profiles } from '../Profiles.js'; export interface ImageControlProps { - selector: (state: OnnxState) => BaseImgParams; - - onChange?: (params: BaseImgParams) => void; + onChange(params: BaseImgParams): void; + selector(state: OnnxState): BaseImgParams; } /** * Doesn't need to use state directly, the parent component knows which params to pass */ export function ImageControl(props: ImageControlProps) { + // eslint-disable-next-line @typescript-eslint/unbound-method + const { onChange, selector } = props; const { params } = mustExist(useContext(ConfigContext)); const state = mustExist(useContext(StateContext)); - const controlState = useStore(state, props.selector); + const controlState = useStore(state, selector); const { t } = useTranslation(); const client = mustExist(useContext(ClientContext)); @@ -40,7 +40,6 @@ export function ImageControl(props: ImageControlProps) { return - { - if (doesExist(props.onChange)) { - props.onChange({ + if (doesExist(onChange)) { + onChange({ ...controlState, scheduler: value, }); @@ -66,8 +65,8 @@ export function ImageControl(props: ImageControlProps) { step={params.eta.step} value={controlState.eta} onChange={(eta) => { - if (doesExist(props.onChange)) { - props.onChange({ + if (doesExist(onChange)) { + onChange({ ...controlState, eta, }); @@ -82,8 +81,8 @@ export function ImageControl(props: ImageControlProps) { step={params.cfg.step} value={controlState.cfg} onChange={(cfg) => { - if (doesExist(props.onChange)) { - props.onChange({ + if (doesExist(onChange)) { + onChange({ ...controlState, cfg, }); @@ -97,12 +96,10 @@ export function ImageControl(props: ImageControlProps) { step={params.steps.step} value={controlState.steps} onChange={(steps) => { - if (doesExist(props.onChange)) { - props.onChange({ - ...controlState, - steps, - }); - } + onChange({ + ...controlState, + steps, + }); }} /> { - if (doesExist(props.onChange)) { - props.onChange({ - ...controlState, - seed, - }); - } + onChange({ + ...controlState, + seed, + }); }} /> - + }} + /> + result.diffusion, + }} + value={model.model} + onChange={(newModel) => { + setModel({ + model: newModel, + }); + }} + /> + result.upscaling, + }} + value={model.upscaling} + onChange={(upscaling) => { + setModel({ + upscaling, + }); + }} + /> + result.correction, + }} + value={model.correction} + onChange={(correction) => { + setModel({ + correction, + }); + }} + /> + ; } diff --git a/gui/src/components/control/UpscaleControl.tsx b/gui/src/components/control/UpscaleControl.tsx index 4eff8c83..f9d739b1 100644 --- a/gui/src/components/control/UpscaleControl.tsx +++ b/gui/src/components/control/UpscaleControl.tsx @@ -3,17 +3,21 @@ import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, import * as React from 'react'; import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; -import { useStore } from 'zustand'; -import { ConfigContext, StateContext } from '../../state.js'; +import { UpscaleParams } from '../../client/types.js'; +import { ConfigContext } from '../../state.js'; import { NumericField } from '../input/NumericField.js'; -export function UpscaleControl() { - const { params } = mustExist(useContext(ConfigContext)); - const state = mustExist(useContext(StateContext)); - const upscale = useStore(state, (s) => s.upscale); +export interface UpscaleControlProps { + upscale: UpscaleParams; + setUpscale(params: Partial): void; +} + +export function UpscaleControl(props: UpscaleControlProps) { // eslint-disable-next-line @typescript-eslint/unbound-method - const setUpscale = useStore(state, (s) => s.setUpscale); + const { upscale, setUpscale } = props; + + const { params } = mustExist(useContext(ConfigContext)); const { t } = useTranslation(); return @@ -22,7 +26,7 @@ export function UpscaleControl() { control={ { + onChange={(_event) => { setUpscale({ enabled: upscale.enabled === false, }); diff --git a/gui/src/components/input/MaskCanvas.tsx b/gui/src/components/input/MaskCanvas.tsx index e59ddae0..0bc4673c 100644 --- a/gui/src/components/input/MaskCanvas.tsx +++ b/gui/src/components/input/MaskCanvas.tsx @@ -4,8 +4,8 @@ import { Button, Stack, Typography } from '@mui/material'; import { throttle } from 'lodash'; import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react'; import { useTranslation } from 'react-i18next'; -import { useStore } from 'zustand'; +import { BrushParams } from '../../client/types.js'; import { SAVE_TIME } from '../../config.js'; import { ConfigContext, LoggerContext, StateContext } from '../../state.js'; import { imageFromBlob } from '../../utils.js'; @@ -36,14 +36,17 @@ export interface Point { } export interface MaskCanvasProps { + brush: BrushParams; source?: Maybe; mask?: Maybe; - onSave: (blob: Blob) => void; + onSave(blob: Blob): void; + setBrush(brush: Partial): void; } export function MaskCanvas(props: MaskCanvasProps) { - const { source, mask } = props; + // eslint-disable-next-line @typescript-eslint/unbound-method + const { source, mask, brush, setBrush } = props; const { params } = mustExist(useContext(ConfigContext)); const logger = mustExist(useContext(LoggerContext)); @@ -202,9 +205,6 @@ export function MaskCanvas(props: MaskCanvasProps) { }); const state = mustExist(useContext(StateContext)); - const brush = useStore(state, (s) => s.brush); - // eslint-disable-next-line @typescript-eslint/unbound-method - const setBrush = useStore(state, (s) => s.setBrush); const { t } = useTranslation(); useEffect(() => { diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx index 475c6c68..2a6a7287 100644 --- a/gui/src/components/input/PromptInput.tsx +++ b/gui/src/components/input/PromptInput.tsx @@ -1,16 +1,23 @@ -import { doesExist, Maybe } from '@apextoaster/js-utils'; +import { doesExist, mustExist } from '@apextoaster/js-utils'; import { TextField } from '@mui/material'; import { Stack } from '@mui/system'; +import { useQuery } from '@tanstack/react-query'; import * as React from 'react'; import { useTranslation } from 'react-i18next'; +import { QueryMenu } from '../input/QueryMenu.js'; +import { STALE_TIME } from '../../config.js'; +import { ClientContext } from '../../state.js'; + +const { useContext } = React; + export interface PromptValue { prompt: string; negativePrompt?: string; } export interface PromptInputProps extends PromptValue { - onChange?: Maybe<(value: PromptValue) => void>; + onChange: (value: PromptValue) => void; } export const PROMPT_GROUP = 75; @@ -29,12 +36,24 @@ export function PromptInput(props: PromptInputProps) { const tokens = splitPrompt(prompt); const groups = Math.ceil(tokens.length / PROMPT_GROUP); + const client = mustExist(useContext(ClientContext)); + const models = useQuery(['models'], async () => client.models(), { + staleTime: STALE_TIME, + }); + const { t } = useTranslation(); const helper = t('input.prompt.tokens', { groups, tokens: tokens.length, }); + function addToken(type: string, name: string, weight = 1.0) { + props.onChange({ + prompt: `<${type}:${name}:1.0> ${prompt}`, + negativePrompt, + }); + } + return + + result.networks.filter((network) => network.type === 'inversion').map((network) => network.name), + }} + onSelect={(name) => { + addToken('inversion', name); + }} + /> + result.networks.filter((network) => network.type === 'lora').map((network) => network.name), + }} + onSelect={(name) => { + addToken('lora', name); + }} + /> + ; } diff --git a/gui/src/components/tab/Blend.tsx b/gui/src/components/tab/Blend.tsx index 7d355742..f54159eb 100644 --- a/gui/src/components/tab/Blend.tsx +++ b/gui/src/components/tab/Blend.tsx @@ -15,12 +15,12 @@ import { MaskCanvas } from '../input/MaskCanvas.js'; export function Blend() { async function uploadSource() { - const { model, blend, upscale } = state.getState(); - const { image, retry } = await client.blend(model, { + const { blend, blendModel, blendUpscale } = state.getState(); + const { image, retry } = await client.blend(blendModel, { ...blend, mask: mustExist(blend.mask), sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist - }, upscale); + }, blendUpscale); pushHistory(image, retry); } @@ -32,10 +32,16 @@ export function Blend() { }); const state = mustExist(useContext(StateContext)); + const brush = useStore(state, (s) => s.blendBrush); const blend = useStore(state, (s) => s.blend); + const upscale = useStore(state, (s) => s.blendUpscale); // eslint-disable-next-line @typescript-eslint/unbound-method const setBlend = useStore(state, (s) => s.setBlend); // eslint-disable-next-line @typescript-eslint/unbound-method + const setBrush = useStore(state, (s) => s.setBlendBrush); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setUpscale = useStore(state, (s) => s.setBlendUpscale); + // eslint-disable-next-line @typescript-eslint/unbound-method const pushHistory = useStore(state, (s) => s.pushHistory); const { t } = useTranslation(); @@ -61,6 +67,7 @@ export function Blend() { /> )} { @@ -68,8 +75,9 @@ export function Blend() { mask, }); }} + setBrush={setBrush} /> - +