1
0
Fork 0

clean up react hooks in models tab

This commit is contained in:
Sean Sube 2023-05-07 23:02:56 -05:00
parent a0fdffab23
commit 1e66ffd252
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
13 changed files with 468 additions and 94 deletions

View File

@ -1,5 +1,5 @@
/* eslint-disable max-lines */
import { doesExist, InvalidArgumentError } from '@apextoaster/js-utils';
import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
import { ServerParams } from '../config.js';
import { range } from '../utils.js';
@ -134,7 +134,7 @@ export function appendHighresToURL(url: URL, highres: HighresParams) {
/**
* Make an API client using the given API root and fetch client.
*/
export function makeClient(root: string, f = fetch): ApiClient {
export function makeClient(root: string, token: Maybe<string> = undefined, f = fetch): ApiClient {
function parseRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
return f(url, options).then((res) => parseApiResponse(root, res));
}
@ -142,12 +142,21 @@ export function makeClient(root: string, f = fetch): ApiClient {
return {
async extras(): Promise<ExtrasFile> {
const path = makeApiUrl(root, 'extras');
if (doesExist(token)) {
path.searchParams.append('token', token);
}
const res = await f(path);
return await res.json() as ExtrasFile;
},
async writeExtras(extras: ExtrasFile): Promise<WriteExtrasResponse> {
const path = makeApiUrl(root, 'extras');
if (doesExist(token)) {
path.searchParams.append('token', token);
}
const res = await f(path, {
body: JSON.stringify(extras),
method: 'PUT',
@ -454,18 +463,24 @@ export function makeClient(root: string, f = fetch): ApiClient {
throw new InvalidArgumentError('unknown request type');
}
},
async restart(token: string): Promise<boolean> {
async restart(): Promise<boolean> {
const path = makeApiUrl(root, 'restart');
path.searchParams.append('token', token);
if (doesExist(token)) {
path.searchParams.append('token', token);
}
const res = await f(path, {
method: 'POST',
});
return res.status === STATUS_SUCCESS;
},
async status(token: string): Promise<Array<unknown>> {
async status(): Promise<Array<unknown>> {
const path = makeApiUrl(root, 'status');
path.searchParams.append('token', token);
if (doesExist(token)) {
path.searchParams.append('token', token);
}
const res = await f(path);
return res.json();

View File

@ -69,10 +69,10 @@ export const LOCAL_CLIENT = {
async strings() {
return {};
},
async restart(token) {
async restart() {
throw new NoServerError();
},
async status(token) {
async status() {
throw new NoServerError();
}
} as ApiClient;

View File

@ -364,10 +364,10 @@ export interface ApiClient {
/**
* Restart the image job workers.
*/
restart(token: string): Promise<boolean>;
restart(): Promise<boolean>;
/**
* Check the status of the image job workers.
*/
status(token: string): Promise<Array<unknown>>;
status(): Promise<Array<unknown>>;
}

View File

@ -21,12 +21,9 @@ export function ModelControl() {
const setModel = useStore(state, (s) => s.setModel);
const { t } = useTranslation();
// get token from query string
const query = new URLSearchParams(window.location.search);
const token = query.get('token');
const [hash, _setHash] = useHash();
const restart = useMutation(['restart'], async () => client.restart(mustExist(token)));
const restart = useMutation(['restart'], async () => client.restart());
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});

View File

@ -1,28 +1,44 @@
import { mustExist } from '@apextoaster/js-utils';
import { Button, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { useStore } from 'zustand';
const { useState } = React;
import { OnnxState, StateContext } from '../../state';
const { useContext, useState } = React;
export interface EditableListProps<T> {
items: Array<T>;
// items: Array<T>;
selector: (s: OnnxState) => Array<T>;
newItem: (l: string, s: string) => T;
renderItem: (t: T) => React.ReactElement;
setItems: (ts: Array<T>) => void;
// removeItem: (t: T) => void;
renderItem: (props: {
model: T;
onChange: (t: T) => void;
}) => React.ReactElement;
setItem: (t: T) => void;
}
export function EditableList<T>(props: EditableListProps<T>) {
const { items, newItem, renderItem, setItems } = props;
const state = mustExist(useContext(StateContext));
const items = useStore(state, props.selector);
const { newItem, renderItem, setItem } = props;
const [nextLabel, setNextLabel] = useState('');
const [nextSource, setNextSource] = useState('');
return <Stack spacing={2}>
{items.map((it, idx) => <Stack direction='row' key={idx} spacing={2}>
{renderItem(it)}
<Button onClick={() => setItems([
...items.slice(0, idx),
...items.slice(idx + 1, items.length),
])}>Remove</Button>
{items.map((model, idx) => <Stack direction='row' key={idx} spacing={2}>
{renderItem({
model,
onChange(t) {
setItem(t);
},
})}
<Button onClick={() => {
// removeItem(model);
}}>Remove</Button>
</Stack>)}
<Stack direction='row' spacing={2}>
<TextField
@ -38,8 +54,9 @@ export function EditableList<T>(props: EditableListProps<T>) {
onChange={(event) => setNextSource(event.target.value)}
/>
<Button onClick={() => {
setItems([...items, newItem(nextLabel, nextSource)]);
setItem(newItem(nextLabel, nextSource));
setNextLabel('');
setNextSource('');
}}>New</Button>
</Stack>
</Stack>;

View File

@ -1,17 +1,61 @@
import { Stack, TextField } from '@mui/material';
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { CorrectionModel } from '../../../types';
export interface CorrectionModelInputProps {
model: CorrectionModel;
onChange: (model: CorrectionModel) => void;
}
export function CorrectionModelInput(props: CorrectionModelInputProps) {
const { model } = props;
const { model, onChange } = props;
return <Stack direction='row' spacing={2}>
<TextField value={model.label} />
<TextField value={model.source} />
<TextField
value={model.label}
onChange={(event) => {
onChange({
...model,
label: event.target.value,
});
}}
/>
<TextField
value={model.source}
onChange={(event) => {
onChange({
...model,
source: event.target.value,
});
}}
/>
<Select
value={model.format}
label='Format'
onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'safetensors',
});
}}
>
<MenuItem value='ckpt'>ckpt</MenuItem>
<MenuItem value='safetensors'>safetensors</MenuItem>
</Select>
<Select
value={model.model}
label='Type'
onChange={(selection) => {
onChange({
...model,
model: selection.target.value as 'codeformer',
});
}}
>
<MenuItem value='codeformer'>Codeformer</MenuItem>
<MenuItem value='gfpgan'>GFPGAN</MenuItem>
</Select>
</Stack>;
}

View File

@ -1,19 +1,36 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { DiffusionModel } from '../../../types';
import { DiffusionModel } from '../../../types.js';
export interface DiffusionModelInputProps {
model: DiffusionModel;
onChange: (model: DiffusionModel) => void;
}
export function DiffusionModelInput(props: DiffusionModelInputProps) {
const { model } = props;
const { model, onChange } = props;
return <Stack direction='row' spacing={2}>
<TextField label='Label' value={model.label} />
<TextField label='Source' value={model.source} />
<Select value={model.format} label='Format'>
<TextField label='Label' value={model.label} onChange={(event) => {
onChange({
...model,
label: event.target.value,
});
}} />
<TextField label='Source' value={model.source} onChange={(event) => {
onChange({
...model,
source: event.target.value,
});
}} />
<Select value={model.format} label='Format' onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'ckpt' | 'safetensors',
});
}}>
<MenuItem value='ckpt'>ckpt</MenuItem>
<MenuItem value='safetensors'>safetensors</MenuItem>
</Select>

View File

@ -1,23 +1,66 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { ExtraNetwork } from '../../../types';
import { ExtraNetwork } from '../../../types.js';
export interface ExtraNetworkInputProps {
model: ExtraNetwork;
onChange: (model: ExtraNetwork) => void;
}
export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
const { model } = props;
const { model, onChange } = props;
return <Stack direction='row' spacing={2}>
<TextField value={model.label} label='Label' />
<TextField value={model.source} label='Source' />
<Select value={model.type} label='Type'>
<TextField
label='Label'
value={model.label}
onChange={(event) => {
onChange({
...model,
label: event.target.value,
});
}}
/>
<TextField
label='Source'
value={model.source}
onChange={(event) => {
onChange({
...model,
source: event.target.value,
});
}}
/>
<Select
label='Format'
value={model.format}
onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'safetensors',
});
}}
>
<MenuItem value='ckpt'>ckpt</MenuItem>
<MenuItem value='safetensors'>safetensors</MenuItem>
</Select>
<Select value={model.type} label='Type' onChange={(selection) => {
onChange({
...model,
type: selection.target.value as 'lora',
});
}}>
<MenuItem value='inversion'>Textual Inversion</MenuItem>
<MenuItem value='lora'>LoRA or LyCORIS</MenuItem>
</Select>
<Select value={model.model} label='Model'>
<Select value={model.model} label='Model' onChange={(selection) => {
onChange({
...model,
model: selection.target.value as 'sd-scripts',
});
}}>
<MenuItem value='sd-scripts'>LoRA - sd-scripts</MenuItem>
<MenuItem value='concept'>TI - concept</MenuItem>
<MenuItem value='embeddings'>TI - embeddings</MenuItem>

View File

@ -1,17 +1,48 @@
import * as React from 'react';
import { Stack, TextField } from '@mui/material';
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import { ExtraSource } from '../../../types';
export interface ExtraSourceInputProps {
model: ExtraSource;
onChange: (model: ExtraSource) => void;
}
export function ExtraSourceInput(props: ExtraSourceInputProps) {
const { model } = props;
const { model, onChange } = props;
return <Stack direction='row' spacing={2}>
<TextField label='dest' value={model.dest} />
<TextField label='source' value={model.source} />
<TextField label='Name' value={model.name} onChange={(event) => {
onChange({
...model,
name: event.target.value,
});
}} />
<TextField label='Source' value={model.source} onChange={(event) => {
onChange({
...model,
source: event.target.value,
});
}} />
<Select
label='Format'
value={model.format}
onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'safetensors',
});
}}
>
<MenuItem value='ckpt'>ckpt</MenuItem>
<MenuItem value='safetensors'>safetensors</MenuItem>
</Select>
<TextField label='Folder' value={model.dest} onChange={(event) => {
onChange({
...model,
dest: event.target.value,
});
}} />
</Stack>;
}

View File

@ -1,17 +1,62 @@
import { Stack, TextField } from '@mui/material';
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { UpscalingModel } from '../../../types.js';
import { NumericField } from '../NumericField.js';
export interface UpscalingModelInputProps {
model: UpscalingModel;
onChange: (model: UpscalingModel) => void;
}
export function UpscalingModelInput(props: UpscalingModelInputProps) {
const { model } = props;
const { model, onChange } = props;
return <Stack direction='row' spacing={2}>
<TextField value={model.label} />
<TextField value={model.source} />
<TextField value={model.label} label='Label' onChange={(event) => {
onChange({
...model,
label: event.target.value,
});
}} />
<TextField value={model.source} label='Source' onChange={(event) => {
onChange({
...model,
source: event.target.value,
});
}} />
<Select value={model.format} label='Format' onChange={(selection) => {
onChange({
...model,
format: selection.target.value as 'ckpt',
});
}}>
<MenuItem value='ckpt'>ckpt</MenuItem>
<MenuItem value='safetensors'>safetensors</MenuItem>
</Select>
<Select value={model.model} label='Type' onChange={(selection) => {
onChange({
...model,
model: selection.target.value as 'bsrgan',
});
}}>
<MenuItem value='bsrgan'>BSRGAN</MenuItem>
<MenuItem value='resrgan'>Real ESRGAN</MenuItem>
<MenuItem value='swinir'>SwinIR</MenuItem>
</Select>
<NumericField
label='Scale'
min={1}
max={4}
step={1}
value={model.scale}
onChange={(value) => {
onChange({
...model,
scale: value,
});
}}
/>
</Stack>;
}

View File

@ -1,12 +1,13 @@
import { mustExist } from '@apextoaster/js-utils';
import { Accordion, AccordionDetails, AccordionSummary, Button, Stack } from '@mui/material';
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Accordion, AccordionDetails, AccordionSummary, Alert, Button, CircularProgress, Stack } from '@mui/material';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import _ from 'lodash';
import * as React from 'react';
import { useStore } from 'zustand';
import { ClientContext, StateContext } from '../../state.js';
import { SafetensorFormat } from '../../types.js';
import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, SafetensorFormat, UpscalingModel } from '../../types.js';
import { EditableList } from '../input/EditableList';
import { CorrectionModelInput } from '../input/model/CorrectionModel.js';
import { DiffusionModelInput } from '../input/model/DiffusionModel.js';
@ -14,49 +15,117 @@ import { ExtraNetworkInput } from '../input/model/ExtraNetwork.js';
import { ExtraSourceInput } from '../input/model/ExtraSource.js';
import { UpscalingModelInput } from '../input/model/UpscalingModel.js';
const { useContext } = React;
const { useContext, useEffect } = React;
// eslint-disable-next-line @typescript-eslint/unbound-method
const { kebabCase } = _;
function mergeModelLists<T extends DiffusionModel | ExtraSource>(local: Array<T>, server: Array<T>) {
const localNames = new Set(local.map((it) => it.name));
const merged = [...local];
for (const model of server) {
if (localNames.has(model.name) === false) {
merged.push(model);
}
}
return merged;
}
function mergeModels(local: ExtrasFile, server: ExtrasFile): ExtrasFile {
const merged: ExtrasFile = {
...server,
correction: mergeModelLists(local.correction, server.correction),
diffusion: mergeModelLists(local.diffusion, server.diffusion),
networks: mergeModelLists(local.networks, server.networks),
sources: mergeModelLists(local.sources, server.sources),
upscaling: mergeModelLists(local.upscaling, server.upscaling),
};
return merged;
}
function selectDiffusionModels(state: OnnxState): Array<DiffusionModel> {
return state.extras.diffusion;
}
function selectCorrectionModels(state: OnnxState): Array<CorrectionModel> {
return state.extras.correction;
}
function selectUpscalingModels(state: OnnxState): Array<UpscalingModel> {
return state.extras.upscaling;
}
function selectExtraNetworks(state: OnnxState): Array<ExtraNetwork> {
return state.extras.networks;
}
function selectExtraSources(state: OnnxState): Array<ExtraSource> {
return state.extras.sources;
}
export function Models() {
const state = mustExist(React.useContext(StateContext));
const extras = useStore(state, (s) => s.extras);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setExtras = useStore(state, (s) => s.setExtras);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setCorrectionModel = useStore(state, (s) => s.setCorrectionModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setDiffusionModel = useStore(state, (s) => s.setDiffusionModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setExtraNetwork = useStore(state, (s) => s.setExtraNetwork);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setExtraSource = useStore(state, (s) => s.setExtraSource);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscalingModel = useStore(state, (s) => s.setUpscalingModel);
const client = mustExist(useContext(ClientContext));
const serverExtras = useQuery(['extras'], async () => client.extras(), {
// staleTime: STALE_TIME,
const result = useQuery(['extras'], async () => client.extras(), {
staleTime: STALE_TIME,
});
async function writeExtras() {
const resp = await client.writeExtras(extras);
}
const query = useQueryClient();
const write = useMutation(writeExtras, {
onSuccess: () => query.invalidateQueries([ 'extras' ]),
});
useEffect(() => {
if (result.status === 'success' && doesExist(result.data)) {
setExtras(mergeModels(state.getState().extras, result.data));
}
}, [result.status]);
if (result.status === 'error') {
return <Alert severity='error'>Error</Alert>;
}
if (result.status === 'loading') {
return <CircularProgress />;
}
async function writeExtras() {
const resp = await client.writeExtras(state.getState().extras);
// TODO: do something with resp
}
return <Stack spacing={2}>
<Accordion>
<AccordionSummary>
Diffusion Models
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.diffusion}
<EditableList<DiffusionModel>
selector={selectDiffusionModels}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
source: s,
})}
renderItem={(t) => <DiffusionModelInput model={t}/>}
setItems={(diffusion) => setExtras({
...extras,
diffusion,
})}
// removeItem={(m) => { /* TODO */ }}
renderItem={DiffusionModelInput}
setItem={(model) => setDiffusionModel(model)}
/>
</AccordionDetails>
</Accordion>
@ -66,18 +135,15 @@ export function Models() {
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.correction}
selector={selectCorrectionModels}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
source: s,
})}
renderItem={(t) => <CorrectionModelInput model={t}/>}
setItems={(correction) => setExtras({
...extras,
correction,
})}
renderItem={CorrectionModelInput}
setItem={(model) => setCorrectionModel(model)}
/>
</AccordionDetails>
</Accordion>
@ -87,7 +153,7 @@ export function Models() {
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.upscaling}
selector={selectUpscalingModels}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
@ -95,11 +161,8 @@ export function Models() {
scale: 4,
source: s,
})}
renderItem={(t) => <UpscalingModelInput model={t}/>}
setItems={(upscaling) => setExtras({
...extras,
upscaling,
})}
renderItem={UpscalingModelInput}
setItem={(model) => setUpscalingModel(model)}
/>
</AccordionDetails>
</Accordion>
@ -109,7 +172,7 @@ export function Models() {
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.networks}
selector={selectExtraNetworks}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
@ -118,11 +181,8 @@ export function Models() {
source: s,
type: 'inversion' as const,
})}
renderItem={(t) => <ExtraNetworkInput model={t}/>}
setItems={(networks) => setExtras({
...extras,
networks,
})}
renderItem={ExtraNetworkInput}
setItem={(model) => setExtraNetwork(model)}
/>
</AccordionDetails>
</Accordion>
@ -132,18 +192,15 @@ export function Models() {
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.sources}
selector={selectExtraSources}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
source: s,
})}
renderItem={(t) => <ExtraSourceInput model={t}/>}
setItems={(sources) => setExtras({
...extras,
sources,
})}
renderItem={ExtraSourceInput}
setItem={(model) => setExtraSource(model)}
/>
</AccordionDetails>
</Accordion>

View File

@ -143,9 +143,13 @@ export async function main() {
// load config from GUI server
const config = await loadConfig();
// get token from query string
const query = new URLSearchParams(window.location.search);
const token = query.get('token');
// use that to create an API client
const root = getApiRoot(config);
const client = makeClient(root);
const client = makeClient(root, token);
// prep react-dom
const appElement = mustExist(document.getElementById('app'));

View File

@ -24,7 +24,9 @@ import {
UpscaleReqParams,
} from './client/types.js';
import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js';
import { ExtrasFile } from './types.js';
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from './types.js';
export const MISSING_INDEX = -1;
export type Theme = PaletteMode | ''; // tri-state, '' is unset
@ -57,6 +59,12 @@ interface ExtraSlice {
extras: ExtrasFile;
setExtras(extras: Partial<ExtrasFile>): void;
setCorrectionModel(model: CorrectionModel): void;
setDiffusionModel(model: DiffusionModel): void;
setExtraNetwork(model: ExtraNetwork): void;
setExtraSource(model: ExtraSource): void;
setUpscalingModel(model: UpscalingModel): void;
}
interface HistorySlice {
@ -558,6 +566,7 @@ export function createStateSlices(server: ServerParams) {
},
});
// eslint-disable-next-line sonarjs/cognitive-complexity
const createExtraSlice: Slice<ExtraSlice> = (set) => ({
extras: {
correction: [],
@ -575,6 +584,101 @@ export function createStateSlices(server: ServerParams) {
},
}));
},
setCorrectionModel(model) {
set((prev) => {
const correction = [...prev.extras.correction];
const exists = correction.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
correction.push(model);
} else {
correction[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
setDiffusionModel(model) {
set((prev) => {
const diffusion = [...prev.extras.diffusion];
const exists = diffusion.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
diffusion.push(model);
} else {
diffusion[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
setExtraNetwork(model) {
set((prev) => {
const networks = [...prev.extras.networks];
const exists = networks.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
networks.push(model);
} else {
networks[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
setExtraSource(model) {
set((prev) => {
const sources = [...prev.extras.sources];
const exists = sources.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
sources.push(model);
} else {
sources[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
setUpscalingModel(model) {
set((prev) => {
const upscaling = [...prev.extras.upscaling];
const exists = upscaling.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
upscaling.push(model);
} else {
upscaling[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
});
return {