diff --git a/gui/src/components/input/EditableList.tsx b/gui/src/components/input/EditableList.tsx index 16541584..5c6cf476 100644 --- a/gui/src/components/input/EditableList.tsx +++ b/gui/src/components/input/EditableList.tsx @@ -3,25 +3,25 @@ import { Button, Stack, TextField } from '@mui/material'; import * as React from 'react'; import { useStore } from 'zustand'; -import { OnnxState, StateContext } from '../../state'; +import { OnnxState, StateContext } from '../../state.js'; const { useContext, useState, memo, useMemo } = React; export interface EditableListProps { - // items: Array; selector: (s: OnnxState) => Array; newItem: (l: string, s: string) => T; - // removeItem: (t: T) => void; + removeItem: (t: T) => void; renderItem: (props: { model: T; onChange: (t: T) => void; + onRemove: (t: T) => void; }) => React.ReactElement; setItem: (t: T) => void; } export function EditableList(props: EditableListProps) { - const { newItem, renderItem, setItem, selector } = props; + const { newItem, removeItem, renderItem, setItem, selector } = props; const state = mustExist(useContext(StateContext)); const items = useStore(state, selector); @@ -30,15 +30,14 @@ export function EditableList(props: EditableListProps) { const RenderMemo = useMemo(() => memo(renderItem), [renderItem]); return - {items.map((model, idx) => + {items.map((model, idx) => - - )} + )} void; + onRemove: (model: CorrectionModel) => void; } export function CorrectionModelInput(props: CorrectionModelInputProps) { - const { model, onChange } = props; + const { key, model, onChange, onRemove } = props; - return + return { @@ -37,7 +39,7 @@ export function CorrectionModelInput(props: CorrectionModelInputProps) { onChange={(selection) => { onChange({ ...model, - format: selection.target.value as 'safetensors', + format: selection.target.value as ModelFormat, }); }} > @@ -50,12 +52,13 @@ export function CorrectionModelInput(props: CorrectionModelInputProps) { onChange={(selection) => { onChange({ ...model, - model: selection.target.value as 'codeformer', + model: selection.target.value as CorrectionArch, }); }} > Codeformer GFPGAN + ; } diff --git a/gui/src/components/input/model/DiffusionModel.tsx b/gui/src/components/input/model/DiffusionModel.tsx index 7a416d01..c755c285 100644 --- a/gui/src/components/input/model/DiffusionModel.tsx +++ b/gui/src/components/input/model/DiffusionModel.tsx @@ -1,18 +1,20 @@ -import { MenuItem, Select, Stack, TextField } from '@mui/material'; +import { Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; -import { DiffusionModel } from '../../../types.js'; +import { DiffusionModel, ModelFormat } from '../../../types.js'; export interface DiffusionModelInputProps { + key?: number | string; model: DiffusionModel; onChange: (model: DiffusionModel) => void; + onRemove: (model: DiffusionModel) => void; } export function DiffusionModelInput(props: DiffusionModelInputProps) { - const { model, onChange } = props; + const { key, model, onChange, onRemove } = props; - return + return { onChange({ ...model, @@ -28,11 +30,12 @@ export function DiffusionModelInput(props: DiffusionModelInputProps) { + ; } diff --git a/gui/src/components/input/model/ExtraNetwork.tsx b/gui/src/components/input/model/ExtraNetwork.tsx index 1cc652ef..478ef5c5 100644 --- a/gui/src/components/input/model/ExtraNetwork.tsx +++ b/gui/src/components/input/model/ExtraNetwork.tsx @@ -1,18 +1,20 @@ -import { MenuItem, Select, Stack, TextField } from '@mui/material'; +import { Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; -import { ExtraNetwork } from '../../../types.js'; +import { ExtraNetwork, ModelFormat, NetworkModel, NetworkType } from '../../../types.js'; export interface ExtraNetworkInputProps { + key?: number | string; model: ExtraNetwork; onChange: (model: ExtraNetwork) => void; + onRemove: (model: ExtraNetwork) => void; } export function ExtraNetworkInput(props: ExtraNetworkInputProps) { - const { model, onChange } = props; + const { key, model, onChange, onRemove } = props; - return + return { onChange({ ...model, - format: selection.target.value as 'safetensors', + format: selection.target.value as ModelFormat, }); }} > @@ -50,7 +52,7 @@ export function ExtraNetworkInput(props: ExtraNetworkInputProps) { { onChange({ ...model, - model: selection.target.value as 'sd-scripts', + model: selection.target.value as NetworkModel, }); }}> LoRA - sd-scripts TI - concept TI - embeddings + ; } diff --git a/gui/src/components/input/model/ExtraSource.tsx b/gui/src/components/input/model/ExtraSource.tsx index 43230a5c..753b559a 100644 --- a/gui/src/components/input/model/ExtraSource.tsx +++ b/gui/src/components/input/model/ExtraSource.tsx @@ -1,18 +1,20 @@ +import { Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; -import { MenuItem, Select, Stack, TextField } from '@mui/material'; -import { ExtraSource } from '../../../types'; +import { AnyFormat, ExtraSource } from '../../../types.js'; export interface ExtraSourceInputProps { + key?: number | string; model: ExtraSource; onChange: (model: ExtraSource) => void; + onRemove: (model: ExtraSource) => void; } export function ExtraSourceInput(props: ExtraSourceInputProps) { - const { model, onChange } = props; + const { key, model, onChange, onRemove } = props; - return + return { onChange({ ...model, @@ -31,7 +33,7 @@ export function ExtraSourceInput(props: ExtraSourceInputProps) { onChange={(selection) => { onChange({ ...model, - format: selection.target.value as 'safetensors', + format: selection.target.value as AnyFormat, }); }} > @@ -46,5 +48,6 @@ export function ExtraSourceInput(props: ExtraSourceInputProps) { dest: event.target.value, }); }} /> + ; } diff --git a/gui/src/components/input/model/UpscalingModel.tsx b/gui/src/components/input/model/UpscalingModel.tsx index 020838a4..57617474 100644 --- a/gui/src/components/input/model/UpscalingModel.tsx +++ b/gui/src/components/input/model/UpscalingModel.tsx @@ -1,19 +1,21 @@ -import { MenuItem, Select, Stack, TextField } from '@mui/material'; +import { Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; -import { UpscalingModel } from '../../../types.js'; +import { ModelFormat, UpscalingArch, UpscalingModel } from '../../../types.js'; import { NumericField } from '../NumericField.js'; export interface UpscalingModelInputProps { + key?: number | string; model: UpscalingModel; onChange: (model: UpscalingModel) => void; + onRemove: (model: UpscalingModel) => void; } export function UpscalingModelInput(props: UpscalingModelInputProps) { - const { model, onChange } = props; + const { key, model, onChange, onRemove } = props; - return + return { onChange({ ...model, @@ -29,7 +31,7 @@ export function UpscalingModelInput(props: UpscalingModelInputProps) { { onChange({ ...model, - model: selection.target.value as 'bsrgan', + model: selection.target.value as UpscalingArch, }); }}> BSRGAN @@ -58,5 +60,6 @@ export function UpscalingModelInput(props: UpscalingModelInputProps) { }); }} /> + ; } diff --git a/gui/src/components/tab/Models.tsx b/gui/src/components/tab/Models.tsx index 546c14d9..b4043719 100644 --- a/gui/src/components/tab/Models.tsx +++ b/gui/src/components/tab/Models.tsx @@ -7,7 +7,7 @@ import { useStore } from 'zustand'; import { STALE_TIME } from '../../config.js'; import { ClientContext, OnnxState, StateContext } from '../../state.js'; -import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, SafetensorFormat, UpscalingModel } from '../../types.js'; +import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, NetworkModel, NetworkType, SafetensorFormat, UpscalingModel } from '../../types.js'; import { EditableList } from '../input/EditableList'; import { CorrectionModelInput } from '../input/model/CorrectionModel.js'; import { DiffusionModelInput } from '../input/model/DiffusionModel.js'; @@ -79,6 +79,16 @@ export function Models() { const setExtraSource = useStore(state, (s) => s.setExtraSource); // eslint-disable-next-line @typescript-eslint/unbound-method const setUpscalingModel = useStore(state, (s) => s.setUpscalingModel); + // eslint-disable-next-line @typescript-eslint/unbound-method + const removeCorrectionModel = useStore(state, (s) => s.removeCorrectionModel); + // eslint-disable-next-line @typescript-eslint/unbound-method + const removeDiffusionModel = useStore(state, (s) => s.removeDiffusionModel); + // eslint-disable-next-line @typescript-eslint/unbound-method + const removeExtraNetwork = useStore(state, (s) => s.removeExtraNetwork); + // eslint-disable-next-line @typescript-eslint/unbound-method + const removeExtraSource = useStore(state, (s) => s.removeExtraSource); + // eslint-disable-next-line @typescript-eslint/unbound-method + const removeUpscalingModel = useStore(state, (s) => s.removeUpscalingModel); const client = mustExist(useContext(ClientContext)); const result = useQuery(['extras'], async () => client.extras(), { @@ -123,7 +133,7 @@ export function Models() { name: kebabCase(l), source: s, })} - // removeItem={(m) => { /* TODO */ }} + removeItem={(m) => removeDiffusionModel(m)} renderItem={DiffusionModelInput} setItem={(model) => setDiffusionModel(model)} /> @@ -142,6 +152,7 @@ export function Models() { name: kebabCase(l), source: s, })} + removeItem={(m) => removeCorrectionModel(m)} renderItem={CorrectionModelInput} setItem={(model) => setCorrectionModel(model)} /> @@ -161,6 +172,7 @@ export function Models() { scale: 4, source: s, })} + removeItem={(m) => removeUpscalingModel(m)} renderItem={UpscalingModelInput} setItem={(model) => setUpscalingModel(model)} /> @@ -176,11 +188,12 @@ export function Models() { newItem={(l, s) => ({ format: 'safetensors' as SafetensorFormat, label: l, - model: 'embeddings' as const, + model: 'embeddings' as NetworkModel, name: kebabCase(l), source: s, - type: 'inversion' as const, + type: 'inversion' as NetworkType, })} + removeItem={(m) => removeExtraNetwork(m)} renderItem={ExtraNetworkInput} setItem={(model) => setExtraNetwork(model)} /> @@ -199,6 +212,7 @@ export function Models() { name: kebabCase(l), source: s, })} + removeItem={(m) => removeExtraSource(m)} renderItem={ExtraSourceInput} setItem={(model) => setExtraSource(model)} /> diff --git a/gui/src/state.ts b/gui/src/state.ts index afb0d04e..12d78334 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -65,6 +65,12 @@ interface ExtraSlice { setExtraNetwork(model: ExtraNetwork): void; setExtraSource(model: ExtraSource): void; setUpscalingModel(model: UpscalingModel): void; + + removeCorrectionModel(model: CorrectionModel): void; + removeDiffusionModel(model: DiffusionModel): void; + removeExtraNetwork(model: ExtraNetwork): void; + removeExtraSource(model: ExtraSource): void; + removeUpscalingModel(model: UpscalingModel): void; } interface HistorySlice { @@ -679,6 +685,70 @@ export function createStateSlices(server: ServerParams) { }; }); }, + removeCorrectionModel(model) { + set((prev) => { + const correction = prev.extras.correction.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + correction, + }, + }; + }); + + }, + removeDiffusionModel(model) { + set((prev) => { + const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + diffusion, + }, + }; + }); + + }, + removeExtraNetwork(model) { + set((prev) => { + const networks = prev.extras.networks.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + networks, + }, + }; + }); + + }, + removeExtraSource(model) { + set((prev) => { + const sources = prev.extras.sources.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + sources, + }, + }; + }); + + }, + removeUpscalingModel(model) { + set((prev) => { + const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + upscaling, + }, + }; + }); + }, }); return { diff --git a/gui/src/types.ts b/gui/src/types.ts index d203bc9a..8d1907b6 100644 --- a/gui/src/types.ts +++ b/gui/src/types.ts @@ -1,6 +1,16 @@ export type TorchFormat = 'bin' | 'ckpt' | 'pt' | 'pth'; export type OnnxFormat = 'onnx'; export type SafetensorFormat = 'safetensors'; +export type TensorFormat = TorchFormat | SafetensorFormat; +export type ModelFormat = TensorFormat | OnnxFormat; +export type MarkupFormat = 'json' | 'yaml'; +export type AnyFormat = MarkupFormat | ModelFormat; + +export type UpscalingArch = 'bsrgan' | 'resrgan' | 'swinir'; +export type CorrectionArch = 'codeformer' | 'gfpgan'; + +export type NetworkType = 'inversion' | 'lora'; +export type NetworkModel = 'concept' | 'embeddings' | 'cloneofsimo' | 'sd-scripts'; export interface BaseModel { /** @@ -35,17 +45,17 @@ export interface DiffusionModel extends BaseModel { } export interface UpscalingModel extends BaseModel { - model?: 'bsrgan' | 'resrgan' | 'swinir'; + model?: UpscalingArch; scale: number; } export interface CorrectionModel extends BaseModel { - model?: 'codeformer' | 'gfpgan'; + model?: CorrectionArch; } export interface ExtraNetwork extends BaseModel { - model: 'concept' | 'embeddings' | 'cloneofsimo' | 'sd-scripts'; - type: 'inversion' | 'lora'; + model: NetworkModel; + type: NetworkType; } export interface ExtraSource {