optimize editable lists, add arch types
This commit is contained in:
parent
534fc70e29
commit
1566ceef7f
|
@ -3,25 +3,25 @@ import { Button, Stack, TextField } from '@mui/material';
|
|||
import * as React from 'react';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { OnnxState, StateContext } from '../../state';
|
||||
import { OnnxState, StateContext } from '../../state.js';
|
||||
|
||||
const { useContext, useState, memo, useMemo } = React;
|
||||
|
||||
export interface EditableListProps<T> {
|
||||
// items: Array<T>;
|
||||
selector: (s: OnnxState) => Array<T>;
|
||||
|
||||
newItem: (l: string, s: string) => T;
|
||||
// removeItem: (t: T) => void;
|
||||
removeItem: (t: T) => void;
|
||||
renderItem: (props: {
|
||||
model: T;
|
||||
onChange: (t: T) => void;
|
||||
onRemove: (t: T) => void;
|
||||
}) => React.ReactElement;
|
||||
setItem: (t: T) => void;
|
||||
}
|
||||
|
||||
export function EditableList<T>(props: EditableListProps<T>) {
|
||||
const { newItem, renderItem, setItem, selector } = props;
|
||||
const { newItem, removeItem, renderItem, setItem, selector } = props;
|
||||
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const items = useStore(state, selector);
|
||||
|
@ -30,15 +30,14 @@ export function EditableList<T>(props: EditableListProps<T>) {
|
|||
const RenderMemo = useMemo(() => memo(renderItem), [renderItem]);
|
||||
|
||||
return <Stack spacing={2}>
|
||||
{items.map((model, idx) => <Stack direction='row' key={idx} spacing={2}>
|
||||
{items.map((model, idx) =>
|
||||
<RenderMemo
|
||||
key={idx}
|
||||
model={model}
|
||||
onChange={setItem}
|
||||
onRemove={removeItem}
|
||||
/>
|
||||
<Button onClick={() => {
|
||||
// removeItem(model);
|
||||
}}>Remove</Button>
|
||||
</Stack>)}
|
||||
)}
|
||||
<Stack direction='row' spacing={2}>
|
||||
<TextField
|
||||
label='Label'
|
||||
|
|
|
@ -1,18 +1,20 @@
|
|||
import { MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
|
||||
import { CorrectionModel } from '../../../types';
|
||||
import { CorrectionArch, CorrectionModel, ModelFormat } from '../../../types.js';
|
||||
|
||||
export interface CorrectionModelInputProps {
|
||||
key?: number | string;
|
||||
model: CorrectionModel;
|
||||
|
||||
onChange: (model: CorrectionModel) => void;
|
||||
onRemove: (model: CorrectionModel) => void;
|
||||
}
|
||||
|
||||
export function CorrectionModelInput(props: CorrectionModelInputProps) {
|
||||
const { model, onChange } = props;
|
||||
const { key, model, onChange, onRemove } = props;
|
||||
|
||||
return <Stack direction='row' spacing={2}>
|
||||
return <Stack direction='row' spacing={2} key={key}>
|
||||
<TextField
|
||||
value={model.label}
|
||||
onChange={(event) => {
|
||||
|
@ -37,7 +39,7 @@ export function CorrectionModelInput(props: CorrectionModelInputProps) {
|
|||
onChange={(selection) => {
|
||||
onChange({
|
||||
...model,
|
||||
format: selection.target.value as 'safetensors',
|
||||
format: selection.target.value as ModelFormat,
|
||||
});
|
||||
}}
|
||||
>
|
||||
|
@ -50,12 +52,13 @@ export function CorrectionModelInput(props: CorrectionModelInputProps) {
|
|||
onChange={(selection) => {
|
||||
onChange({
|
||||
...model,
|
||||
model: selection.target.value as 'codeformer',
|
||||
model: selection.target.value as CorrectionArch,
|
||||
});
|
||||
}}
|
||||
>
|
||||
<MenuItem value='codeformer'>Codeformer</MenuItem>
|
||||
<MenuItem value='gfpgan'>GFPGAN</MenuItem>
|
||||
</Select>
|
||||
<Button onClick={() => onRemove(model)}>Remove</Button>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -1,18 +1,20 @@
|
|||
import { MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
|
||||
import { DiffusionModel } from '../../../types.js';
|
||||
import { DiffusionModel, ModelFormat } from '../../../types.js';
|
||||
|
||||
export interface DiffusionModelInputProps {
|
||||
key?: number | string;
|
||||
model: DiffusionModel;
|
||||
|
||||
onChange: (model: DiffusionModel) => void;
|
||||
onRemove: (model: DiffusionModel) => void;
|
||||
}
|
||||
|
||||
export function DiffusionModelInput(props: DiffusionModelInputProps) {
|
||||
const { model, onChange } = props;
|
||||
const { key, model, onChange, onRemove } = props;
|
||||
|
||||
return <Stack direction='row' spacing={2}>
|
||||
return <Stack direction='row' spacing={2} key={key}>
|
||||
<TextField label='Label' value={model.label} onChange={(event) => {
|
||||
onChange({
|
||||
...model,
|
||||
|
@ -28,11 +30,12 @@ export function DiffusionModelInput(props: DiffusionModelInputProps) {
|
|||
<Select value={model.format} label='Format' onChange={(selection) => {
|
||||
onChange({
|
||||
...model,
|
||||
format: selection.target.value as 'ckpt' | 'safetensors',
|
||||
format: selection.target.value as ModelFormat,
|
||||
});
|
||||
}}>
|
||||
<MenuItem value='ckpt'>ckpt</MenuItem>
|
||||
<MenuItem value='safetensors'>safetensors</MenuItem>
|
||||
</Select>
|
||||
<Button onClick={() => onRemove(model)}>Remove</Button>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -1,18 +1,20 @@
|
|||
import { MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
|
||||
import { ExtraNetwork } from '../../../types.js';
|
||||
import { ExtraNetwork, ModelFormat, NetworkModel, NetworkType } from '../../../types.js';
|
||||
|
||||
export interface ExtraNetworkInputProps {
|
||||
key?: number | string;
|
||||
model: ExtraNetwork;
|
||||
|
||||
onChange: (model: ExtraNetwork) => void;
|
||||
onRemove: (model: ExtraNetwork) => void;
|
||||
}
|
||||
|
||||
export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
|
||||
const { model, onChange } = props;
|
||||
const { key, model, onChange, onRemove } = props;
|
||||
|
||||
return <Stack direction='row' spacing={2}>
|
||||
return <Stack direction='row' spacing={2} key={key}>
|
||||
<TextField
|
||||
label='Label'
|
||||
value={model.label}
|
||||
|
@ -39,7 +41,7 @@ export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
|
|||
onChange={(selection) => {
|
||||
onChange({
|
||||
...model,
|
||||
format: selection.target.value as 'safetensors',
|
||||
format: selection.target.value as ModelFormat,
|
||||
});
|
||||
}}
|
||||
>
|
||||
|
@ -50,7 +52,7 @@ export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
|
|||
<Select value={model.type} label='Type' onChange={(selection) => {
|
||||
onChange({
|
||||
...model,
|
||||
type: selection.target.value as 'lora',
|
||||
type: selection.target.value as NetworkType,
|
||||
});
|
||||
}}>
|
||||
<MenuItem value='inversion'>Textual Inversion</MenuItem>
|
||||
|
@ -59,12 +61,13 @@ export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
|
|||
<Select value={model.model} label='Model' onChange={(selection) => {
|
||||
onChange({
|
||||
...model,
|
||||
model: selection.target.value as 'sd-scripts',
|
||||
model: selection.target.value as NetworkModel,
|
||||
});
|
||||
}}>
|
||||
<MenuItem value='sd-scripts'>LoRA - sd-scripts</MenuItem>
|
||||
<MenuItem value='concept'>TI - concept</MenuItem>
|
||||
<MenuItem value='embeddings'>TI - embeddings</MenuItem>
|
||||
</Select>
|
||||
<Button onClick={() => onRemove(model)}>Remove</Button>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -1,18 +1,20 @@
|
|||
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
|
||||
import { ExtraSource } from '../../../types';
|
||||
import { AnyFormat, ExtraSource } from '../../../types.js';
|
||||
|
||||
export interface ExtraSourceInputProps {
|
||||
key?: number | string;
|
||||
model: ExtraSource;
|
||||
|
||||
onChange: (model: ExtraSource) => void;
|
||||
onRemove: (model: ExtraSource) => void;
|
||||
}
|
||||
|
||||
export function ExtraSourceInput(props: ExtraSourceInputProps) {
|
||||
const { model, onChange } = props;
|
||||
const { key, model, onChange, onRemove } = props;
|
||||
|
||||
return <Stack direction='row' spacing={2}>
|
||||
return <Stack direction='row' spacing={2} key={key}>
|
||||
<TextField label='Name' value={model.name} onChange={(event) => {
|
||||
onChange({
|
||||
...model,
|
||||
|
@ -31,7 +33,7 @@ export function ExtraSourceInput(props: ExtraSourceInputProps) {
|
|||
onChange={(selection) => {
|
||||
onChange({
|
||||
...model,
|
||||
format: selection.target.value as 'safetensors',
|
||||
format: selection.target.value as AnyFormat,
|
||||
});
|
||||
}}
|
||||
>
|
||||
|
@ -46,5 +48,6 @@ export function ExtraSourceInput(props: ExtraSourceInputProps) {
|
|||
dest: event.target.value,
|
||||
});
|
||||
}} />
|
||||
<Button onClick={() => onRemove(model)}>Remove</Button>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -1,19 +1,21 @@
|
|||
import { MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
|
||||
import { UpscalingModel } from '../../../types.js';
|
||||
import { ModelFormat, UpscalingArch, UpscalingModel } from '../../../types.js';
|
||||
import { NumericField } from '../NumericField.js';
|
||||
|
||||
export interface UpscalingModelInputProps {
|
||||
key?: number | string;
|
||||
model: UpscalingModel;
|
||||
|
||||
onChange: (model: UpscalingModel) => void;
|
||||
onRemove: (model: UpscalingModel) => void;
|
||||
}
|
||||
|
||||
export function UpscalingModelInput(props: UpscalingModelInputProps) {
|
||||
const { model, onChange } = props;
|
||||
const { key, model, onChange, onRemove } = props;
|
||||
|
||||
return <Stack direction='row' spacing={2}>
|
||||
return <Stack direction='row' spacing={2} key={key}>
|
||||
<TextField value={model.label} label='Label' onChange={(event) => {
|
||||
onChange({
|
||||
...model,
|
||||
|
@ -29,7 +31,7 @@ export function UpscalingModelInput(props: UpscalingModelInputProps) {
|
|||
<Select value={model.format} label='Format' onChange={(selection) => {
|
||||
onChange({
|
||||
...model,
|
||||
format: selection.target.value as 'ckpt',
|
||||
format: selection.target.value as ModelFormat,
|
||||
});
|
||||
}}>
|
||||
<MenuItem value='ckpt'>ckpt</MenuItem>
|
||||
|
@ -38,7 +40,7 @@ export function UpscalingModelInput(props: UpscalingModelInputProps) {
|
|||
<Select value={model.model} label='Type' onChange={(selection) => {
|
||||
onChange({
|
||||
...model,
|
||||
model: selection.target.value as 'bsrgan',
|
||||
model: selection.target.value as UpscalingArch,
|
||||
});
|
||||
}}>
|
||||
<MenuItem value='bsrgan'>BSRGAN</MenuItem>
|
||||
|
@ -58,5 +60,6 @@ export function UpscalingModelInput(props: UpscalingModelInputProps) {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
<Button onClick={() => onRemove(model)}>Remove</Button>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import { useStore } from 'zustand';
|
|||
|
||||
import { STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, SafetensorFormat, UpscalingModel } from '../../types.js';
|
||||
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, NetworkModel, NetworkType, SafetensorFormat, UpscalingModel } from '../../types.js';
|
||||
import { EditableList } from '../input/EditableList';
|
||||
import { CorrectionModelInput } from '../input/model/CorrectionModel.js';
|
||||
import { DiffusionModelInput } from '../input/model/DiffusionModel.js';
|
||||
|
@ -79,6 +79,16 @@ export function Models() {
|
|||
const setExtraSource = useStore(state, (s) => s.setExtraSource);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setUpscalingModel = useStore(state, (s) => s.setUpscalingModel);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const removeCorrectionModel = useStore(state, (s) => s.removeCorrectionModel);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const removeDiffusionModel = useStore(state, (s) => s.removeDiffusionModel);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const removeExtraNetwork = useStore(state, (s) => s.removeExtraNetwork);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const removeExtraSource = useStore(state, (s) => s.removeExtraSource);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const removeUpscalingModel = useStore(state, (s) => s.removeUpscalingModel);
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
|
||||
const result = useQuery(['extras'], async () => client.extras(), {
|
||||
|
@ -123,7 +133,7 @@ export function Models() {
|
|||
name: kebabCase(l),
|
||||
source: s,
|
||||
})}
|
||||
// removeItem={(m) => { /* TODO */ }}
|
||||
removeItem={(m) => removeDiffusionModel(m)}
|
||||
renderItem={DiffusionModelInput}
|
||||
setItem={(model) => setDiffusionModel(model)}
|
||||
/>
|
||||
|
@ -142,6 +152,7 @@ export function Models() {
|
|||
name: kebabCase(l),
|
||||
source: s,
|
||||
})}
|
||||
removeItem={(m) => removeCorrectionModel(m)}
|
||||
renderItem={CorrectionModelInput}
|
||||
setItem={(model) => setCorrectionModel(model)}
|
||||
/>
|
||||
|
@ -161,6 +172,7 @@ export function Models() {
|
|||
scale: 4,
|
||||
source: s,
|
||||
})}
|
||||
removeItem={(m) => removeUpscalingModel(m)}
|
||||
renderItem={UpscalingModelInput}
|
||||
setItem={(model) => setUpscalingModel(model)}
|
||||
/>
|
||||
|
@ -176,11 +188,12 @@ export function Models() {
|
|||
newItem={(l, s) => ({
|
||||
format: 'safetensors' as SafetensorFormat,
|
||||
label: l,
|
||||
model: 'embeddings' as const,
|
||||
model: 'embeddings' as NetworkModel,
|
||||
name: kebabCase(l),
|
||||
source: s,
|
||||
type: 'inversion' as const,
|
||||
type: 'inversion' as NetworkType,
|
||||
})}
|
||||
removeItem={(m) => removeExtraNetwork(m)}
|
||||
renderItem={ExtraNetworkInput}
|
||||
setItem={(model) => setExtraNetwork(model)}
|
||||
/>
|
||||
|
@ -199,6 +212,7 @@ export function Models() {
|
|||
name: kebabCase(l),
|
||||
source: s,
|
||||
})}
|
||||
removeItem={(m) => removeExtraSource(m)}
|
||||
renderItem={ExtraSourceInput}
|
||||
setItem={(model) => setExtraSource(model)}
|
||||
/>
|
||||
|
|
|
@ -65,6 +65,12 @@ interface ExtraSlice {
|
|||
setExtraNetwork(model: ExtraNetwork): void;
|
||||
setExtraSource(model: ExtraSource): void;
|
||||
setUpscalingModel(model: UpscalingModel): void;
|
||||
|
||||
removeCorrectionModel(model: CorrectionModel): void;
|
||||
removeDiffusionModel(model: DiffusionModel): void;
|
||||
removeExtraNetwork(model: ExtraNetwork): void;
|
||||
removeExtraSource(model: ExtraSource): void;
|
||||
removeUpscalingModel(model: UpscalingModel): void;
|
||||
}
|
||||
|
||||
interface HistorySlice {
|
||||
|
@ -679,6 +685,70 @@ export function createStateSlices(server: ServerParams) {
|
|||
};
|
||||
});
|
||||
},
|
||||
removeCorrectionModel(model) {
|
||||
set((prev) => {
|
||||
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
correction,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeDiffusionModel(model) {
|
||||
set((prev) => {
|
||||
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
diffusion,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeExtraNetwork(model) {
|
||||
set((prev) => {
|
||||
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
networks,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeExtraSource(model) {
|
||||
set((prev) => {
|
||||
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
sources,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeUpscalingModel(model) {
|
||||
set((prev) => {
|
||||
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
upscaling,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
|
|
|
@ -1,6 +1,16 @@
|
|||
export type TorchFormat = 'bin' | 'ckpt' | 'pt' | 'pth';
|
||||
export type OnnxFormat = 'onnx';
|
||||
export type SafetensorFormat = 'safetensors';
|
||||
export type TensorFormat = TorchFormat | SafetensorFormat;
|
||||
export type ModelFormat = TensorFormat | OnnxFormat;
|
||||
export type MarkupFormat = 'json' | 'yaml';
|
||||
export type AnyFormat = MarkupFormat | ModelFormat;
|
||||
|
||||
export type UpscalingArch = 'bsrgan' | 'resrgan' | 'swinir';
|
||||
export type CorrectionArch = 'codeformer' | 'gfpgan';
|
||||
|
||||
export type NetworkType = 'inversion' | 'lora';
|
||||
export type NetworkModel = 'concept' | 'embeddings' | 'cloneofsimo' | 'sd-scripts';
|
||||
|
||||
export interface BaseModel {
|
||||
/**
|
||||
|
@ -35,17 +45,17 @@ export interface DiffusionModel extends BaseModel {
|
|||
}
|
||||
|
||||
export interface UpscalingModel extends BaseModel {
|
||||
model?: 'bsrgan' | 'resrgan' | 'swinir';
|
||||
model?: UpscalingArch;
|
||||
scale: number;
|
||||
}
|
||||
|
||||
export interface CorrectionModel extends BaseModel {
|
||||
model?: 'codeformer' | 'gfpgan';
|
||||
model?: CorrectionArch;
|
||||
}
|
||||
|
||||
export interface ExtraNetwork extends BaseModel {
|
||||
model: 'concept' | 'embeddings' | 'cloneofsimo' | 'sd-scripts';
|
||||
type: 'inversion' | 'lora';
|
||||
model: NetworkModel;
|
||||
type: NetworkType;
|
||||
}
|
||||
|
||||
export interface ExtraSource {
|
||||
|
|
Loading…
Reference in New Issue