1
0
Fork 0

optimize editable lists, add arch types

This commit is contained in:
Sean Sube 2023-05-08 21:41:15 -05:00
parent 534fc70e29
commit 1566ceef7f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 154 additions and 46 deletions

View File

@ -3,25 +3,25 @@ import { Button, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { useStore } from 'zustand';
import { OnnxState, StateContext } from '../../state';
import { OnnxState, StateContext } from '../../state.js';
const { useContext, useState, memo, useMemo } = React;
export interface EditableListProps<T> {
// items: Array<T>;
selector: (s: OnnxState) => Array<T>;
newItem: (l: string, s: string) => T;
// removeItem: (t: T) => void;
removeItem: (t: T) => void;
renderItem: (props: {
model: T;
onChange: (t: T) => void;
onRemove: (t: T) => void;
}) => React.ReactElement;
setItem: (t: T) => void;
}
export function EditableList<T>(props: EditableListProps<T>) {
const { newItem, renderItem, setItem, selector } = props;
const { newItem, removeItem, renderItem, setItem, selector } = props;
const state = mustExist(useContext(StateContext));
const items = useStore(state, selector);
@ -30,15 +30,14 @@ export function EditableList<T>(props: EditableListProps<T>) {
const RenderMemo = useMemo(() => memo(renderItem), [renderItem]);
return <Stack spacing={2}>
{items.map((model, idx) => <Stack direction='row' key={idx} spacing={2}>
{items.map((model, idx) =>
<RenderMemo
key={idx}
model={model}
onChange={setItem}
onRemove={removeItem}
/>
<Button onClick={() => {
// removeItem(model);
}}>Remove</Button>
</Stack>)}
)}
<Stack direction='row' spacing={2}>
<TextField
label='Label'

View File

@ -1,18 +1,20 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { CorrectionModel } from '../../../types';
import { CorrectionArch, CorrectionModel, ModelFormat } from '../../../types.js';
export interface CorrectionModelInputProps {
key?: number | string;
model: CorrectionModel;
onChange: (model: CorrectionModel) => void;
onRemove: (model: CorrectionModel) => void;
}
export function CorrectionModelInput(props: CorrectionModelInputProps) {
const { model, onChange } = props;
const { key, model, onChange, onRemove } = props;
return <Stack direction='row' spacing={2}>
return <Stack direction='row' spacing={2} key={key}>
<TextField
value={model.label}
onChange={(event) => {
@ -37,7 +39,7 @@ export function CorrectionModelInput(props: CorrectionModelInputProps) {
onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'safetensors',
format: selection.target.value as ModelFormat,
});
}}
>
@ -50,12 +52,13 @@ export function CorrectionModelInput(props: CorrectionModelInputProps) {
onChange={(selection) => {
onChange({
...model,
model: selection.target.value as 'codeformer',
model: selection.target.value as CorrectionArch,
});
}}
>
<MenuItem value='codeformer'>Codeformer</MenuItem>
<MenuItem value='gfpgan'>GFPGAN</MenuItem>
</Select>
<Button onClick={() => onRemove(model)}>Remove</Button>
</Stack>;
}

View File

@ -1,18 +1,20 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { DiffusionModel } from '../../../types.js';
import { DiffusionModel, ModelFormat } from '../../../types.js';
export interface DiffusionModelInputProps {
key?: number | string;
model: DiffusionModel;
onChange: (model: DiffusionModel) => void;
onRemove: (model: DiffusionModel) => void;
}
export function DiffusionModelInput(props: DiffusionModelInputProps) {
const { model, onChange } = props;
const { key, model, onChange, onRemove } = props;
return <Stack direction='row' spacing={2}>
return <Stack direction='row' spacing={2} key={key}>
<TextField label='Label' value={model.label} onChange={(event) => {
onChange({
...model,
@ -28,11 +30,12 @@ export function DiffusionModelInput(props: DiffusionModelInputProps) {
<Select value={model.format} label='Format' onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'ckpt' | 'safetensors',
format: selection.target.value as ModelFormat,
});
}}>
<MenuItem value='ckpt'>ckpt</MenuItem>
<MenuItem value='safetensors'>safetensors</MenuItem>
</Select>
<Button onClick={() => onRemove(model)}>Remove</Button>
</Stack>;
}

View File

@ -1,18 +1,20 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { ExtraNetwork } from '../../../types.js';
import { ExtraNetwork, ModelFormat, NetworkModel, NetworkType } from '../../../types.js';
export interface ExtraNetworkInputProps {
key?: number | string;
model: ExtraNetwork;
onChange: (model: ExtraNetwork) => void;
onRemove: (model: ExtraNetwork) => void;
}
export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
const { model, onChange } = props;
const { key, model, onChange, onRemove } = props;
return <Stack direction='row' spacing={2}>
return <Stack direction='row' spacing={2} key={key}>
<TextField
label='Label'
value={model.label}
@ -39,7 +41,7 @@ export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'safetensors',
format: selection.target.value as ModelFormat,
});
}}
>
@ -50,7 +52,7 @@ export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
<Select value={model.type} label='Type' onChange={(selection) => {
onChange({
...model,
type: selection.target.value as 'lora',
type: selection.target.value as NetworkType,
});
}}>
<MenuItem value='inversion'>Textual Inversion</MenuItem>
@ -59,12 +61,13 @@ export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
<Select value={model.model} label='Model' onChange={(selection) => {
onChange({
...model,
model: selection.target.value as 'sd-scripts',
model: selection.target.value as NetworkModel,
});
}}>
<MenuItem value='sd-scripts'>LoRA - sd-scripts</MenuItem>
<MenuItem value='concept'>TI - concept</MenuItem>
<MenuItem value='embeddings'>TI - embeddings</MenuItem>
</Select>
<Button onClick={() => onRemove(model)}>Remove</Button>
</Stack>;
}

View File

@ -1,18 +1,20 @@
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import { ExtraSource } from '../../../types';
import { AnyFormat, ExtraSource } from '../../../types.js';
export interface ExtraSourceInputProps {
key?: number | string;
model: ExtraSource;
onChange: (model: ExtraSource) => void;
onRemove: (model: ExtraSource) => void;
}
export function ExtraSourceInput(props: ExtraSourceInputProps) {
const { model, onChange } = props;
const { key, model, onChange, onRemove } = props;
return <Stack direction='row' spacing={2}>
return <Stack direction='row' spacing={2} key={key}>
<TextField label='Name' value={model.name} onChange={(event) => {
onChange({
...model,
@ -31,7 +33,7 @@ export function ExtraSourceInput(props: ExtraSourceInputProps) {
onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'safetensors',
format: selection.target.value as AnyFormat,
});
}}
>
@ -46,5 +48,6 @@ export function ExtraSourceInput(props: ExtraSourceInputProps) {
dest: event.target.value,
});
}} />
<Button onClick={() => onRemove(model)}>Remove</Button>
</Stack>;
}

View File

@ -1,19 +1,21 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import { Button, MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { UpscalingModel } from '../../../types.js';
import { ModelFormat, UpscalingArch, UpscalingModel } from '../../../types.js';
import { NumericField } from '../NumericField.js';
export interface UpscalingModelInputProps {
key?: number | string;
model: UpscalingModel;
onChange: (model: UpscalingModel) => void;
onRemove: (model: UpscalingModel) => void;
}
export function UpscalingModelInput(props: UpscalingModelInputProps) {
const { model, onChange } = props;
const { key, model, onChange, onRemove } = props;
return <Stack direction='row' spacing={2}>
return <Stack direction='row' spacing={2} key={key}>
<TextField value={model.label} label='Label' onChange={(event) => {
onChange({
...model,
@ -29,7 +31,7 @@ export function UpscalingModelInput(props: UpscalingModelInputProps) {
<Select value={model.format} label='Format' onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'ckpt',
format: selection.target.value as ModelFormat,
});
}}>
<MenuItem value='ckpt'>ckpt</MenuItem>
@ -38,7 +40,7 @@ export function UpscalingModelInput(props: UpscalingModelInputProps) {
<Select value={model.model} label='Type' onChange={(selection) => {
onChange({
...model,
model: selection.target.value as 'bsrgan',
model: selection.target.value as UpscalingArch,
});
}}>
<MenuItem value='bsrgan'>BSRGAN</MenuItem>
@ -58,5 +60,6 @@ export function UpscalingModelInput(props: UpscalingModelInputProps) {
});
}}
/>
<Button onClick={() => onRemove(model)}>Remove</Button>
</Stack>;
}

View File

@ -7,7 +7,7 @@ import { useStore } from 'zustand';
import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, SafetensorFormat, UpscalingModel } from '../../types.js';
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, NetworkModel, NetworkType, SafetensorFormat, UpscalingModel } from '../../types.js';
import { EditableList } from '../input/EditableList';
import { CorrectionModelInput } from '../input/model/CorrectionModel.js';
import { DiffusionModelInput } from '../input/model/DiffusionModel.js';
@ -79,6 +79,16 @@ export function Models() {
const setExtraSource = useStore(state, (s) => s.setExtraSource);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscalingModel = useStore(state, (s) => s.setUpscalingModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeCorrectionModel = useStore(state, (s) => s.removeCorrectionModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeDiffusionModel = useStore(state, (s) => s.removeDiffusionModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeExtraNetwork = useStore(state, (s) => s.removeExtraNetwork);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeExtraSource = useStore(state, (s) => s.removeExtraSource);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeUpscalingModel = useStore(state, (s) => s.removeUpscalingModel);
const client = mustExist(useContext(ClientContext));
const result = useQuery(['extras'], async () => client.extras(), {
@ -123,7 +133,7 @@ export function Models() {
name: kebabCase(l),
source: s,
})}
// removeItem={(m) => { /* TODO */ }}
removeItem={(m) => removeDiffusionModel(m)}
renderItem={DiffusionModelInput}
setItem={(model) => setDiffusionModel(model)}
/>
@ -142,6 +152,7 @@ export function Models() {
name: kebabCase(l),
source: s,
})}
removeItem={(m) => removeCorrectionModel(m)}
renderItem={CorrectionModelInput}
setItem={(model) => setCorrectionModel(model)}
/>
@ -161,6 +172,7 @@ export function Models() {
scale: 4,
source: s,
})}
removeItem={(m) => removeUpscalingModel(m)}
renderItem={UpscalingModelInput}
setItem={(model) => setUpscalingModel(model)}
/>
@ -176,11 +188,12 @@ export function Models() {
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
model: 'embeddings' as const,
model: 'embeddings' as NetworkModel,
name: kebabCase(l),
source: s,
type: 'inversion' as const,
type: 'inversion' as NetworkType,
})}
removeItem={(m) => removeExtraNetwork(m)}
renderItem={ExtraNetworkInput}
setItem={(model) => setExtraNetwork(model)}
/>
@ -199,6 +212,7 @@ export function Models() {
name: kebabCase(l),
source: s,
})}
removeItem={(m) => removeExtraSource(m)}
renderItem={ExtraSourceInput}
setItem={(model) => setExtraSource(model)}
/>

View File

@ -65,6 +65,12 @@ interface ExtraSlice {
setExtraNetwork(model: ExtraNetwork): void;
setExtraSource(model: ExtraSource): void;
setUpscalingModel(model: UpscalingModel): void;
removeCorrectionModel(model: CorrectionModel): void;
removeDiffusionModel(model: DiffusionModel): void;
removeExtraNetwork(model: ExtraNetwork): void;
removeExtraSource(model: ExtraSource): void;
removeUpscalingModel(model: UpscalingModel): void;
}
interface HistorySlice {
@ -679,6 +685,70 @@ export function createStateSlices(server: ServerParams) {
};
});
},
removeCorrectionModel(model) {
set((prev) => {
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
removeDiffusionModel(model) {
set((prev) => {
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
removeExtraNetwork(model) {
set((prev) => {
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
removeExtraSource(model) {
set((prev) => {
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
removeUpscalingModel(model) {
set((prev) => {
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
});
return {

View File

@ -1,6 +1,16 @@
export type TorchFormat = 'bin' | 'ckpt' | 'pt' | 'pth';
export type OnnxFormat = 'onnx';
export type SafetensorFormat = 'safetensors';
export type TensorFormat = TorchFormat | SafetensorFormat;
export type ModelFormat = TensorFormat | OnnxFormat;
export type MarkupFormat = 'json' | 'yaml';
export type AnyFormat = MarkupFormat | ModelFormat;
export type UpscalingArch = 'bsrgan' | 'resrgan' | 'swinir';
export type CorrectionArch = 'codeformer' | 'gfpgan';
export type NetworkType = 'inversion' | 'lora';
export type NetworkModel = 'concept' | 'embeddings' | 'cloneofsimo' | 'sd-scripts';
export interface BaseModel {
/**
@ -35,17 +45,17 @@ export interface DiffusionModel extends BaseModel {
}
export interface UpscalingModel extends BaseModel {
model?: 'bsrgan' | 'resrgan' | 'swinir';
model?: UpscalingArch;
scale: number;
}
export interface CorrectionModel extends BaseModel {
model?: 'codeformer' | 'gfpgan';
model?: CorrectionArch;
}
export interface ExtraNetwork extends BaseModel {
model: 'concept' | 'embeddings' | 'cloneofsimo' | 'sd-scripts';
type: 'inversion' | 'lora';
model: NetworkModel;
type: NetworkType;
}
export interface ExtraSource {