1
0
Fork 0

add more model types to models tab

This commit is contained in:
Sean Sube 2023-05-06 15:38:41 -05:00
parent 6e78f40f09
commit 8c88fcd5fe
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 260 additions and 18 deletions

View File

@ -6,33 +6,40 @@ const { useState } = React;
export interface EditableListProps<T> { export interface EditableListProps<T> {
items: Array<T>; items: Array<T>;
newItem: (s: string) => T; newItem: (l: string, s: string) => T;
renderItem: (t: T) => React.ReactElement; renderItem: (t: T) => React.ReactElement;
setItems: (ts: Array<T>) => void; setItems: (ts: Array<T>) => void;
} }
export function EditableList<T>(props: EditableListProps<T>) { export function EditableList<T>(props: EditableListProps<T>) {
const { items, newItem, renderItem, setItems } = props; const { items, newItem, renderItem, setItems } = props;
const [nextItem, setNextItem] = useState(''); const [nextLabel, setNextLabel] = useState('');
const [nextSource, setNextSource] = useState('');
return <Stack> return <Stack spacing={2}>
{items.map((it, idx) => <Stack direction='row' key={idx}> {items.map((it, idx) => <Stack direction='row' key={idx} spacing={2}>
{renderItem(it)} {renderItem(it)}
<Button onClick={() => setItems([ <Button onClick={() => setItems([
...items.slice(0, idx), ...items.slice(0, idx),
...items.slice(idx + 1, items.length), ...items.slice(idx + 1, items.length),
])}>Remove</Button> ])}>Remove</Button>
</Stack>)} </Stack>)}
<Stack direction='row'> <Stack direction='row' spacing={2}>
<TextField
label='Label'
variant='outlined'
value={nextLabel}
onChange={(event) => setNextLabel(event.target.value)}
/>
<TextField <TextField
label='Source' label='Source'
variant='outlined' variant='outlined'
value={nextItem} value={nextSource}
onChange={(event) => setNextItem(event.target.value)} onChange={(event) => setNextSource(event.target.value)}
/> />
<Button onClick={() => { <Button onClick={() => {
setItems([...items, newItem(nextItem)]); setItems([...items, newItem(nextLabel, nextSource)]);
setNextItem(''); setNextLabel('');
}}>New</Button> }}>New</Button>
</Stack> </Stack>
</Stack>; </Stack>;

View File

@ -0,0 +1,17 @@
import { Stack, TextField } from '@mui/material';
import * as React from 'react';
import { CorrectionModel } from '../../../types';
export interface CorrectionModelInputProps {
model: CorrectionModel;
}
export function CorrectionModelInput(props: CorrectionModelInputProps) {
const { model } = props;
return <Stack direction='row' spacing={2}>
<TextField value={model.label} />
<TextField value={model.source} />
</Stack>;
}

View File

@ -0,0 +1,21 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { DiffusionModel } from '../../../types';
export interface DiffusionModelInputProps {
model: DiffusionModel;
}
export function DiffusionModelInput(props: DiffusionModelInputProps) {
const { model } = props;
return <Stack direction='row' spacing={2}>
<TextField label='Label' value={model.label} />
<TextField label='Source' value={model.source} />
<Select value={model.format} label='Format'>
<MenuItem value='ckpt'>ckpt</MenuItem>
<MenuItem value='safetensors'>safetensors</MenuItem>
</Select>
</Stack>;
}

View File

@ -0,0 +1,26 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { ExtraNetwork } from '../../../types';
export interface ExtraNetworkInputProps {
model: ExtraNetwork;
}
export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
const { model } = props;
return <Stack direction='row' spacing={2}>
<TextField value={model.label} label='Label' />
<TextField value={model.source} label='Source' />
<Select value={model.type} label='Type'>
<MenuItem value='inversion'>Textual Inversion</MenuItem>
<MenuItem value='lora'>LoRA or LyCORIS</MenuItem>
</Select>
<Select value={model.model} label='Model'>
<MenuItem value='sd-scripts'>LoRA - sd-scripts</MenuItem>
<MenuItem value='concept'>TI - concept</MenuItem>
<MenuItem value='embeddings'>TI - embeddings</MenuItem>
</Select>
</Stack>;
}

View File

@ -0,0 +1,17 @@
import * as React from 'react';
import { Stack, TextField } from '@mui/material';
import { ExtraSource } from '../../../types';
export interface ExtraSourceInputProps {
model: ExtraSource;
}
export function ExtraSourceInput(props: ExtraSourceInputProps) {
const { model } = props;
return <Stack direction='row' spacing={2}>
<TextField label='dest' value={model.dest} />
<TextField label='source' value={model.source} />
</Stack>;
}

View File

@ -0,0 +1,17 @@
import { Stack, TextField } from '@mui/material';
import * as React from 'react';
import { UpscalingModel } from '../../../types.js';
export interface UpscalingModelInputProps {
model: UpscalingModel;
}
export function UpscalingModelInput(props: UpscalingModelInputProps) {
const { model } = props;
return <Stack direction='row' spacing={2}>
<TextField value={model.label} />
<TextField value={model.source} />
</Stack>;
}

View File

@ -1,10 +1,19 @@
import { mustExist } from '@apextoaster/js-utils'; import { mustExist } from '@apextoaster/js-utils';
import { Accordion, AccordionDetails, AccordionSummary, Button, Stack } from '@mui/material'; import { Accordion, AccordionDetails, AccordionSummary, Button, Stack } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import _ from 'lodash';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { StateContext } from '../../state.js'; import { StateContext } from '../../state.js';
import { EditableList } from '../input/EditableList'; import { EditableList } from '../input/EditableList';
import { DiffusionModelInput } from '../input/model/DiffusionModel.js';
import { SafetensorFormat } from '../../types.js';
import { CorrectionModelInput } from '../input/model/CorrectionModel.js';
import { UpscalingModelInput } from '../input/model/UpscalingModel.js';
import { ExtraSourceInput } from '../input/model/ExtraSource.js';
import { ExtraNetworkInput } from '../input/model/ExtraNetwork.js';
// eslint-disable-next-line @typescript-eslint/unbound-method
const { kebabCase } = _;
export function Models() { export function Models() {
const state = mustExist(React.useContext(StateContext)); const state = mustExist(React.useContext(StateContext));
@ -12,7 +21,7 @@ export function Models() {
// eslint-disable-next-line @typescript-eslint/unbound-method // eslint-disable-next-line @typescript-eslint/unbound-method
const setExtras = useStore(state, (s) => s.setExtras); const setExtras = useStore(state, (s) => s.setExtras);
return <Stack> return <Stack spacing={2}>
<Accordion> <Accordion>
<AccordionSummary> <AccordionSummary>
Diffusion Models Diffusion Models
@ -20,8 +29,13 @@ export function Models() {
<AccordionDetails> <AccordionDetails>
<EditableList <EditableList
items={extras.diffusion} items={extras.diffusion}
newItem={(s) => s} newItem={(l, s) => ({
renderItem={(t) => <div key={t}>{t}</div>} format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
source: s,
})}
renderItem={(t) => <DiffusionModelInput model={t}/>}
setItems={(diffusion) => setExtras({ setItems={(diffusion) => setExtras({
...extras, ...extras,
diffusion, diffusion,
@ -34,6 +48,20 @@ export function Models() {
Correction Models Correction Models
</AccordionSummary> </AccordionSummary>
<AccordionDetails> <AccordionDetails>
<EditableList
items={extras.correction}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
source: s,
})}
renderItem={(t) => <CorrectionModelInput model={t}/>}
setItems={(correction) => setExtras({
...extras,
correction,
})}
/>
</AccordionDetails> </AccordionDetails>
</Accordion> </Accordion>
<Accordion> <Accordion>
@ -41,13 +69,44 @@ export function Models() {
Upscaling Models Upscaling Models
</AccordionSummary> </AccordionSummary>
<AccordionDetails> <AccordionDetails>
<EditableList
items={extras.upscaling}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
scale: 4,
source: s,
})}
renderItem={(t) => <UpscalingModelInput model={t}/>}
setItems={(upscaling) => setExtras({
...extras,
upscaling,
})}
/>
</AccordionDetails> </AccordionDetails>
</Accordion> </Accordion>
<Accordion> <Accordion>
<AccordionSummary> <AccordionSummary>
Additional Networks Extra Networks
</AccordionSummary> </AccordionSummary>
<AccordionDetails> <AccordionDetails>
<EditableList
items={extras.networks}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
model: 'embeddings' as const,
name: kebabCase(l),
source: s,
type: 'inversion' as const,
})}
renderItem={(t) => <ExtraNetworkInput model={t}/>}
setItems={(networks) => setExtras({
...extras,
networks,
})}
/>
</AccordionDetails> </AccordionDetails>
</Accordion> </Accordion>
<Accordion> <Accordion>
@ -55,6 +114,20 @@ export function Models() {
Other Sources Other Sources
</AccordionSummary> </AccordionSummary>
<AccordionDetails> <AccordionDetails>
<EditableList
items={extras.sources}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
source: s,
})}
renderItem={(t) => <ExtraSourceInput model={t}/>}
setItems={(sources) => setExtras({
...extras,
sources,
})}
/>
</AccordionDetails> </AccordionDetails>
</Accordion> </Accordion>
<Button color='warning'>Save & Convert</Button> <Button color='warning'>Save & Convert</Button>

View File

@ -24,6 +24,7 @@ import {
UpscaleReqParams, UpscaleReqParams,
} from './client/api.js'; } from './client/api.js';
import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js'; import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js';
import { ExtrasFile } from './types.js';
export type Theme = PaletteMode | ''; // tri-state, '' is unset export type Theme = PaletteMode | ''; // tri-state, '' is unset
@ -38,10 +39,6 @@ interface HistoryItem {
retry: RetryParams; retry: RetryParams;
} }
interface ExtrasFile {
diffusion: Array<string>;
}
interface BrushSlice { interface BrushSlice {
brush: BrushParams; brush: BrushParams;
@ -556,7 +553,6 @@ export function createStateSlices(server: ServerParams) {
next.resetTxt2Img(); next.resetTxt2Img();
next.resetUpscaleTab(); next.resetUpscaleTab();
next.resetBlend(); next.resetBlend();
// TODO: reset more stuff
return next; return next;
}); });
}, },
@ -564,7 +560,11 @@ export function createStateSlices(server: ServerParams) {
const createExtraSlice: Slice<ExtraSlice> = (set) => ({ const createExtraSlice: Slice<ExtraSlice> = (set) => ({
extras: { extras: {
correction: [],
diffusion: [], diffusion: [],
networks: [],
sources: [],
upscaling: [],
}, },
setExtras(extras) { setExtras(extras) {
set((prev) => ({ set((prev) => ({

64
gui/src/types.ts Normal file
View File

@ -0,0 +1,64 @@
export type TorchFormat = 'bin' | 'ckpt' | 'pt' | 'pth';
export type OnnxFormat = 'onnx';
export type SafetensorFormat = 'safetensors';
export interface BaseModel {
/**
* Format of the model, used when downloading files that may not have a format in their URL.
*/
format: OnnxFormat | SafetensorFormat | TorchFormat;
/**
* Localized label of the model.
*/
label: string;
/**
* Filename of the model.
*/
name: string;
/**
* Source URL or local path.
*/
source: string;
}
export interface DiffusionModel extends BaseModel {
config?: string;
image_size?: string;
inversions?: Array<unknown>;
loras?: Array<unknown>;
pipeline?: string;
vae?: string;
version?: string;
}
export interface UpscalingModel extends BaseModel {
model?: 'bsrgan' | 'resrgan' | 'swinir';
scale: number;
}
export interface CorrectionModel extends BaseModel {
model?: 'codeformer' | 'gfpgan';
}
export interface ExtraNetwork extends BaseModel {
model: 'concept' | 'embeddings' | 'cloneofsimo' | 'sd-scripts';
type: 'inversion' | 'lora';
}
export interface ExtraSource {
dest?: string;
format?: string;
name: string;
source: string;
}
export interface ExtrasFile {
correction: Array<CorrectionModel>;
diffusion: Array<DiffusionModel>;
upscaling: Array<UpscalingModel>;
networks: Array<ExtraNetwork>;
sources: Array<ExtraSource>;
}