From 1e66ffd25257659ca62d5b15b9428a0dc9bcb5fc Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 7 May 2023 23:02:56 -0500 Subject: [PATCH] clean up react hooks in models tab --- gui/src/client/api.ts | 27 +++- gui/src/client/local.ts | 4 +- gui/src/client/types.ts | 4 +- gui/src/components/control/ModelControl.tsx | 5 +- gui/src/components/input/EditableList.tsx | 41 +++-- .../input/model/CorrectionModel.tsx | 52 ++++++- .../components/input/model/DiffusionModel.tsx | 27 +++- .../components/input/model/ExtraNetwork.tsx | 55 ++++++- .../components/input/model/ExtraSource.tsx | 39 ++++- .../components/input/model/UpscalingModel.tsx | 53 ++++++- gui/src/components/tab/Models.tsx | 143 ++++++++++++------ gui/src/main.tsx | 6 +- gui/src/state.ts | 106 ++++++++++++- 13 files changed, 468 insertions(+), 94 deletions(-) diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 9c7b1c9d..efcfd7dc 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -1,5 +1,5 @@ /* eslint-disable max-lines */ -import { doesExist, InvalidArgumentError } from '@apextoaster/js-utils'; +import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils'; import { ServerParams } from '../config.js'; import { range } from '../utils.js'; @@ -134,7 +134,7 @@ export function appendHighresToURL(url: URL, highres: HighresParams) { /** * Make an API client using the given API root and fetch client. */ -export function makeClient(root: string, f = fetch): ApiClient { +export function makeClient(root: string, token: Maybe = undefined, f = fetch): ApiClient { function parseRequest(url: URL, options: RequestInit): Promise { return f(url, options).then((res) => parseApiResponse(root, res)); } @@ -142,12 +142,21 @@ export function makeClient(root: string, f = fetch): ApiClient { return { async extras(): Promise { const path = makeApiUrl(root, 'extras'); + + if (doesExist(token)) { + path.searchParams.append('token', token); + } + const res = await f(path); return await res.json() as ExtrasFile; }, async writeExtras(extras: ExtrasFile): Promise { const path = makeApiUrl(root, 'extras'); + if (doesExist(token)) { + path.searchParams.append('token', token); + } + const res = await f(path, { body: JSON.stringify(extras), method: 'PUT', @@ -454,18 +463,24 @@ export function makeClient(root: string, f = fetch): ApiClient { throw new InvalidArgumentError('unknown request type'); } }, - async restart(token: string): Promise { + async restart(): Promise { const path = makeApiUrl(root, 'restart'); - path.searchParams.append('token', token); + + if (doesExist(token)) { + path.searchParams.append('token', token); + } const res = await f(path, { method: 'POST', }); return res.status === STATUS_SUCCESS; }, - async status(token: string): Promise> { + async status(): Promise> { const path = makeApiUrl(root, 'status'); - path.searchParams.append('token', token); + + if (doesExist(token)) { + path.searchParams.append('token', token); + } const res = await f(path); return res.json(); diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts index 16f681c1..97f785a8 100644 --- a/gui/src/client/local.ts +++ b/gui/src/client/local.ts @@ -69,10 +69,10 @@ export const LOCAL_CLIENT = { async strings() { return {}; }, - async restart(token) { + async restart() { throw new NoServerError(); }, - async status(token) { + async status() { throw new NoServerError(); } } as ApiClient; diff --git a/gui/src/client/types.ts b/gui/src/client/types.ts index ba59c5e9..8de792b9 100644 --- a/gui/src/client/types.ts +++ b/gui/src/client/types.ts @@ -364,10 +364,10 @@ export interface ApiClient { /** * Restart the image job workers. */ - restart(token: string): Promise; + restart(): Promise; /** * Check the status of the image job workers. */ - status(token: string): Promise>; + status(): Promise>; } diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index 10bc8acf..1d837368 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -21,12 +21,9 @@ export function ModelControl() { const setModel = useStore(state, (s) => s.setModel); const { t } = useTranslation(); - // get token from query string - const query = new URLSearchParams(window.location.search); - const token = query.get('token'); const [hash, _setHash] = useHash(); - const restart = useMutation(['restart'], async () => client.restart(mustExist(token))); + const restart = useMutation(['restart'], async () => client.restart()); const models = useQuery(['models'], async () => client.models(), { staleTime: STALE_TIME, }); diff --git a/gui/src/components/input/EditableList.tsx b/gui/src/components/input/EditableList.tsx index a91da836..5bc0c51d 100644 --- a/gui/src/components/input/EditableList.tsx +++ b/gui/src/components/input/EditableList.tsx @@ -1,28 +1,44 @@ +import { mustExist } from '@apextoaster/js-utils'; import { Button, Stack, TextField } from '@mui/material'; import * as React from 'react'; +import { useStore } from 'zustand'; -const { useState } = React; +import { OnnxState, StateContext } from '../../state'; + +const { useContext, useState } = React; export interface EditableListProps { - items: Array; + // items: Array; + selector: (s: OnnxState) => Array; newItem: (l: string, s: string) => T; - renderItem: (t: T) => React.ReactElement; - setItems: (ts: Array) => void; + // removeItem: (t: T) => void; + renderItem: (props: { + model: T; + onChange: (t: T) => void; + }) => React.ReactElement; + setItem: (t: T) => void; } export function EditableList(props: EditableListProps) { - const { items, newItem, renderItem, setItems } = props; + const state = mustExist(useContext(StateContext)); + const items = useStore(state, props.selector); + + const { newItem, renderItem, setItem } = props; const [nextLabel, setNextLabel] = useState(''); const [nextSource, setNextSource] = useState(''); return - {items.map((it, idx) => - {renderItem(it)} - + {items.map((model, idx) => + {renderItem({ + model, + onChange(t) { + setItem(t); + }, + })} + )} (props: EditableListProps) { onChange={(event) => setNextSource(event.target.value)} /> ; diff --git a/gui/src/components/input/model/CorrectionModel.tsx b/gui/src/components/input/model/CorrectionModel.tsx index 3af0ab7a..2e3b968c 100644 --- a/gui/src/components/input/model/CorrectionModel.tsx +++ b/gui/src/components/input/model/CorrectionModel.tsx @@ -1,17 +1,61 @@ -import { Stack, TextField } from '@mui/material'; +import { MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; import { CorrectionModel } from '../../../types'; export interface CorrectionModelInputProps { model: CorrectionModel; + + onChange: (model: CorrectionModel) => void; } export function CorrectionModelInput(props: CorrectionModelInputProps) { - const { model } = props; + const { model, onChange } = props; return - - + { + onChange({ + ...model, + label: event.target.value, + }); + }} + /> + { + onChange({ + ...model, + source: event.target.value, + }); + }} + /> + + ; } diff --git a/gui/src/components/input/model/DiffusionModel.tsx b/gui/src/components/input/model/DiffusionModel.tsx index bd1dcbc7..7a416d01 100644 --- a/gui/src/components/input/model/DiffusionModel.tsx +++ b/gui/src/components/input/model/DiffusionModel.tsx @@ -1,19 +1,36 @@ import { MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; -import { DiffusionModel } from '../../../types'; +import { DiffusionModel } from '../../../types.js'; export interface DiffusionModelInputProps { model: DiffusionModel; + + onChange: (model: DiffusionModel) => void; } export function DiffusionModelInput(props: DiffusionModelInputProps) { - const { model } = props; + const { model, onChange } = props; return - - - { + onChange({ + ...model, + format: selection.target.value as 'ckpt' | 'safetensors', + }); + }}> ckpt safetensors diff --git a/gui/src/components/input/model/ExtraNetwork.tsx b/gui/src/components/input/model/ExtraNetwork.tsx index 62be0d33..7e0cb781 100644 --- a/gui/src/components/input/model/ExtraNetwork.tsx +++ b/gui/src/components/input/model/ExtraNetwork.tsx @@ -1,23 +1,66 @@ import { MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; -import { ExtraNetwork } from '../../../types'; +import { ExtraNetwork } from '../../../types.js'; export interface ExtraNetworkInputProps { model: ExtraNetwork; + + onChange: (model: ExtraNetwork) => void; } export function ExtraNetworkInput(props: ExtraNetworkInputProps) { - const { model } = props; + const { model, onChange } = props; return - - - { + onChange({ + ...model, + format: selection.target.value as 'safetensors', + }); + }} + > + ckpt + safetensors + + - { + onChange({ + ...model, + model: selection.target.value as 'sd-scripts', + }); + }}> LoRA - sd-scripts TI - concept TI - embeddings diff --git a/gui/src/components/input/model/ExtraSource.tsx b/gui/src/components/input/model/ExtraSource.tsx index 88d6d950..ea8b9fa2 100644 --- a/gui/src/components/input/model/ExtraSource.tsx +++ b/gui/src/components/input/model/ExtraSource.tsx @@ -1,17 +1,48 @@ import * as React from 'react'; -import { Stack, TextField } from '@mui/material'; +import { MenuItem, Select, Stack, TextField } from '@mui/material'; import { ExtraSource } from '../../../types'; export interface ExtraSourceInputProps { model: ExtraSource; + + onChange: (model: ExtraSource) => void; } export function ExtraSourceInput(props: ExtraSourceInputProps) { - const { model } = props; + const { model, onChange } = props; return - - + { + onChange({ + ...model, + name: event.target.value, + }); + }} /> + { + onChange({ + ...model, + source: event.target.value, + }); + }} /> + + { + onChange({ + ...model, + dest: event.target.value, + }); + }} /> ; } diff --git a/gui/src/components/input/model/UpscalingModel.tsx b/gui/src/components/input/model/UpscalingModel.tsx index 39fe2556..020838a4 100644 --- a/gui/src/components/input/model/UpscalingModel.tsx +++ b/gui/src/components/input/model/UpscalingModel.tsx @@ -1,17 +1,62 @@ -import { Stack, TextField } from '@mui/material'; +import { MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; import { UpscalingModel } from '../../../types.js'; +import { NumericField } from '../NumericField.js'; export interface UpscalingModelInputProps { model: UpscalingModel; + + onChange: (model: UpscalingModel) => void; } export function UpscalingModelInput(props: UpscalingModelInputProps) { - const { model } = props; + const { model, onChange } = props; return - - + { + onChange({ + ...model, + label: event.target.value, + }); + }} /> + { + onChange({ + ...model, + source: event.target.value, + }); + }} /> + + + { + onChange({ + ...model, + scale: value, + }); + }} + /> ; } diff --git a/gui/src/components/tab/Models.tsx b/gui/src/components/tab/Models.tsx index 8d207eb1..546c14d9 100644 --- a/gui/src/components/tab/Models.tsx +++ b/gui/src/components/tab/Models.tsx @@ -1,12 +1,13 @@ -import { mustExist } from '@apextoaster/js-utils'; -import { Accordion, AccordionDetails, AccordionSummary, Button, Stack } from '@mui/material'; +import { doesExist, mustExist } from '@apextoaster/js-utils'; +import { Accordion, AccordionDetails, AccordionSummary, Alert, Button, CircularProgress, Stack } from '@mui/material'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import _ from 'lodash'; import * as React from 'react'; import { useStore } from 'zustand'; -import { ClientContext, StateContext } from '../../state.js'; -import { SafetensorFormat } from '../../types.js'; +import { STALE_TIME } from '../../config.js'; +import { ClientContext, OnnxState, StateContext } from '../../state.js'; +import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, SafetensorFormat, UpscalingModel } from '../../types.js'; import { EditableList } from '../input/EditableList'; import { CorrectionModelInput } from '../input/model/CorrectionModel.js'; import { DiffusionModelInput } from '../input/model/DiffusionModel.js'; @@ -14,49 +15,117 @@ import { ExtraNetworkInput } from '../input/model/ExtraNetwork.js'; import { ExtraSourceInput } from '../input/model/ExtraSource.js'; import { UpscalingModelInput } from '../input/model/UpscalingModel.js'; -const { useContext } = React; +const { useContext, useEffect } = React; // eslint-disable-next-line @typescript-eslint/unbound-method const { kebabCase } = _; +function mergeModelLists(local: Array, server: Array) { + const localNames = new Set(local.map((it) => it.name)); + + const merged = [...local]; + for (const model of server) { + if (localNames.has(model.name) === false) { + merged.push(model); + } + } + + return merged; +} + +function mergeModels(local: ExtrasFile, server: ExtrasFile): ExtrasFile { + const merged: ExtrasFile = { + ...server, + correction: mergeModelLists(local.correction, server.correction), + diffusion: mergeModelLists(local.diffusion, server.diffusion), + networks: mergeModelLists(local.networks, server.networks), + sources: mergeModelLists(local.sources, server.sources), + upscaling: mergeModelLists(local.upscaling, server.upscaling), + }; + + return merged; +} + +function selectDiffusionModels(state: OnnxState): Array { + return state.extras.diffusion; +} + +function selectCorrectionModels(state: OnnxState): Array { + return state.extras.correction; +} + +function selectUpscalingModels(state: OnnxState): Array { + return state.extras.upscaling; +} + +function selectExtraNetworks(state: OnnxState): Array { + return state.extras.networks; +} + +function selectExtraSources(state: OnnxState): Array { + return state.extras.sources; +} + export function Models() { const state = mustExist(React.useContext(StateContext)); - const extras = useStore(state, (s) => s.extras); // eslint-disable-next-line @typescript-eslint/unbound-method const setExtras = useStore(state, (s) => s.setExtras); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setCorrectionModel = useStore(state, (s) => s.setCorrectionModel); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setDiffusionModel = useStore(state, (s) => s.setDiffusionModel); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setExtraNetwork = useStore(state, (s) => s.setExtraNetwork); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setExtraSource = useStore(state, (s) => s.setExtraSource); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setUpscalingModel = useStore(state, (s) => s.setUpscalingModel); const client = mustExist(useContext(ClientContext)); - const serverExtras = useQuery(['extras'], async () => client.extras(), { - // staleTime: STALE_TIME, + const result = useQuery(['extras'], async () => client.extras(), { + staleTime: STALE_TIME, }); - async function writeExtras() { - const resp = await client.writeExtras(extras); - } - const query = useQueryClient(); const write = useMutation(writeExtras, { onSuccess: () => query.invalidateQueries([ 'extras' ]), }); + useEffect(() => { + if (result.status === 'success' && doesExist(result.data)) { + setExtras(mergeModels(state.getState().extras, result.data)); + } + }, [result.status]); + + if (result.status === 'error') { + return Error; + } + + if (result.status === 'loading') { + return ; + } + + async function writeExtras() { + const resp = await client.writeExtras(state.getState().extras); + // TODO: do something with resp + } + return Diffusion Models - + selector={selectDiffusionModels} newItem={(l, s) => ({ format: 'safetensors' as SafetensorFormat, label: l, name: kebabCase(l), source: s, })} - renderItem={(t) => } - setItems={(diffusion) => setExtras({ - ...extras, - diffusion, - })} + // removeItem={(m) => { /* TODO */ }} + renderItem={DiffusionModelInput} + setItem={(model) => setDiffusionModel(model)} /> @@ -66,18 +135,15 @@ export function Models() { ({ format: 'safetensors' as SafetensorFormat, label: l, name: kebabCase(l), source: s, })} - renderItem={(t) => } - setItems={(correction) => setExtras({ - ...extras, - correction, - })} + renderItem={CorrectionModelInput} + setItem={(model) => setCorrectionModel(model)} /> @@ -87,7 +153,7 @@ export function Models() { ({ format: 'safetensors' as SafetensorFormat, label: l, @@ -95,11 +161,8 @@ export function Models() { scale: 4, source: s, })} - renderItem={(t) => } - setItems={(upscaling) => setExtras({ - ...extras, - upscaling, - })} + renderItem={UpscalingModelInput} + setItem={(model) => setUpscalingModel(model)} /> @@ -109,7 +172,7 @@ export function Models() { ({ format: 'safetensors' as SafetensorFormat, label: l, @@ -118,11 +181,8 @@ export function Models() { source: s, type: 'inversion' as const, })} - renderItem={(t) => } - setItems={(networks) => setExtras({ - ...extras, - networks, - })} + renderItem={ExtraNetworkInput} + setItem={(model) => setExtraNetwork(model)} /> @@ -132,18 +192,15 @@ export function Models() { ({ format: 'safetensors' as SafetensorFormat, label: l, name: kebabCase(l), source: s, })} - renderItem={(t) => } - setItems={(sources) => setExtras({ - ...extras, - sources, - })} + renderItem={ExtraSourceInput} + setItem={(model) => setExtraSource(model)} /> diff --git a/gui/src/main.tsx b/gui/src/main.tsx index 9fd61acf..5c2217a2 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -143,9 +143,13 @@ export async function main() { // load config from GUI server const config = await loadConfig(); + // get token from query string + const query = new URLSearchParams(window.location.search); + const token = query.get('token'); + // use that to create an API client const root = getApiRoot(config); - const client = makeClient(root); + const client = makeClient(root, token); // prep react-dom const appElement = mustExist(document.getElementById('app')); diff --git a/gui/src/state.ts b/gui/src/state.ts index fc41a96a..afb0d04e 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -24,7 +24,9 @@ import { UpscaleReqParams, } from './client/types.js'; import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js'; -import { ExtrasFile } from './types.js'; +import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from './types.js'; + +export const MISSING_INDEX = -1; export type Theme = PaletteMode | ''; // tri-state, '' is unset @@ -57,6 +59,12 @@ interface ExtraSlice { extras: ExtrasFile; setExtras(extras: Partial): void; + + setCorrectionModel(model: CorrectionModel): void; + setDiffusionModel(model: DiffusionModel): void; + setExtraNetwork(model: ExtraNetwork): void; + setExtraSource(model: ExtraSource): void; + setUpscalingModel(model: UpscalingModel): void; } interface HistorySlice { @@ -558,6 +566,7 @@ export function createStateSlices(server: ServerParams) { }, }); + // eslint-disable-next-line sonarjs/cognitive-complexity const createExtraSlice: Slice = (set) => ({ extras: { correction: [], @@ -575,6 +584,101 @@ export function createStateSlices(server: ServerParams) { }, })); }, + setCorrectionModel(model) { + set((prev) => { + const correction = [...prev.extras.correction]; + const exists = correction.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + correction.push(model); + } else { + correction[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + correction, + }, + }; + }); + }, + setDiffusionModel(model) { + set((prev) => { + const diffusion = [...prev.extras.diffusion]; + const exists = diffusion.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + diffusion.push(model); + } else { + diffusion[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + diffusion, + }, + }; + }); + }, + setExtraNetwork(model) { + set((prev) => { + const networks = [...prev.extras.networks]; + const exists = networks.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + networks.push(model); + } else { + networks[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + networks, + }, + }; + }); + }, + setExtraSource(model) { + set((prev) => { + const sources = [...prev.extras.sources]; + const exists = sources.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + sources.push(model); + } else { + sources[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + sources, + }, + }; + }); + }, + setUpscalingModel(model) { + set((prev) => { + const upscaling = [...prev.extras.upscaling]; + const exists = upscaling.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + upscaling.push(model); + } else { + upscaling[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + upscaling, + }, + }; + }); + }, }); return {