1
0
Fork 0

fix(gui): break up state into slices for each tab

This commit is contained in:
Sean Sube 2023-01-12 17:49:51 -06:00
parent 35e2e1dda6
commit 689a6a183f
5 changed files with 185 additions and 204 deletions

View File

@ -4,37 +4,32 @@ import { useContext } from 'react';
import * as React from 'react'; import * as React from 'react';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ApiResponse } from '../api/client.js';
import { StateContext } from '../main.js'; import { StateContext } from '../main.js';
import { ImageCard } from './ImageCard.js'; import { ImageCard } from './ImageCard.js';
import { LoadingCard } from './LoadingCard.js'; import { LoadingCard } from './LoadingCard.js';
export function ImageHistory() { export function ImageHistory() {
const history = useStore(mustExist(useContext(StateContext)), (state) => state.history); const history = useStore(mustExist(useContext(StateContext)), (state) => state.history);
const limit = useStore(mustExist(useContext(StateContext)), (state) => state.limit);
const loading = useStore(mustExist(useContext(StateContext)), (state) => state.loading);
// eslint-disable-next-line @typescript-eslint/unbound-method // eslint-disable-next-line @typescript-eslint/unbound-method
const setHistory = useStore(mustExist(useContext(StateContext)), (state) => state.setHistory); const removeHistory = useStore(mustExist(useContext(StateContext)), (state) => state.removeHistory);
const { images } = history;
const children = []; const children = [];
if (history.loading) { if (loading) {
children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config
} }
function removeHistory(image: ApiResponse) { if (history.length > 0) {
setHistory(images.filter((item) => image.output !== item.output)); children.push(...history.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
}
if (images.length > 0) {
children.push(...images.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
} else { } else {
if (history.loading === false) { if (loading === false) {
children.push(<div>No results. Press Generate.</div>); children.push(<div>No results. Press Generate.</div>);
} }
} }
const limited = children.slice(0, history.limit); const limited = children.slice(0, limit);
return <Grid container spacing={2}>{limited.map((child, idx) => <Grid item key={idx} xs={6}>{child}</Grid>)}</Grid>; return <Grid container spacing={2}>{limited.map((child, idx) => <Grid item key={idx} xs={6}>{child}</Grid>)}</Grid>;
} }

View File

@ -22,7 +22,7 @@ export function Settings(_props: SettingsProps) {
min={2} min={2}
max={20} max={20}
step={1} step={1}
value={state.history.limit} value={state.limit}
onChange={(value) => state.setLimit(value)} onChange={(value) => state.setLimit(value)}
/> />
<TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => { <TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => {

View File

@ -7,167 +7,26 @@ import { QueryClient, QueryClientProvider } from 'react-query';
import { createStore, StoreApi } from 'zustand'; import { createStore, StoreApi } from 'zustand';
import { createJSONStorage, persist } from 'zustand/middleware'; import { createJSONStorage, persist } from 'zustand/middleware';
import { ApiClient, ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js'; import { ApiClient, makeClient } from './api/client.js';
import { OnnxWeb } from './components/OnnxWeb.js'; import { OnnxWeb } from './components/OnnxWeb.js';
import { ConfigState, loadConfig } from './config.js'; import { loadConfig } from './config.js';
import { createStateSlices, OnnxState } from './state.js';
const { createContext } = React; const { createContext } = React;
interface OnnxState {
defaults: Required<BaseImgParams>;
txt2img: ConfigState<Required<Txt2ImgParams>>;
img2img: ConfigState<Required<Img2ImgParams>>;
inpaint: ConfigState<Required<InpaintParams>>;
setDefaults(newParams: Partial<BaseImgParams>): void;
setTxt2Img(newParams: Partial<ConfigState<Required<Txt2ImgParams>>>): void;
setImg2Img(newParams: Partial<ConfigState<Required<Img2ImgParams>>>): void;
setInpaint(newParams: Partial<ConfigState<Required<InpaintParams>>>): void;
resetTxt2Img(): void;
resetImg2Img(): void;
resetInpaint(): void;
history: {
images: Array<ApiResponse>;
limit: number;
loading: boolean;
};
setLimit(limit: number): void;
setLoading(loading: boolean): void;
setHistory(newHistory: Array<ApiResponse>): void;
pushHistory(newImage: ApiResponse): void;
}
export async function main() { export async function main() {
const config = await loadConfig(); const config = await loadConfig();
const client = makeClient(config.api.root); const client = makeClient(config.api.root);
const params = await client.params(); const params = await client.params();
merge(params, config.params); merge(params, config.params);
const defaults = paramsFromConfig(params); const { createDefaultSlice, createHistorySlice, createImg2ImgSlice, createInpaintSlice, createTxt2ImgSlice } = createStateSlices(params);
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((set) => ({ const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((...slice) => ({
defaults, ...createTxt2ImgSlice(...slice),
history: { ...createImg2ImgSlice(...slice),
images: [], ...createInpaintSlice(...slice),
limit: 4, ...createHistorySlice(...slice),
loading: false, ...createDefaultSlice(...slice),
},
txt2img: {
...defaults,
height: params.height.default,
width: params.width.default,
},
img2img: {
...defaults,
strength: params.strength.default,
},
inpaint: {
...defaults,
},
setDefaults(newParams) {
set((oldState) => ({
...oldState,
defaults: {
...oldState.defaults,
...newParams,
},
}));
},
setTxt2Img(newParams) {
set((oldState) => ({
...oldState,
txt2img: {
...oldState.txt2img,
...newParams,
},
}));
},
setImg2Img(newParams) {
set((oldState) => ({
...oldState,
img2img: {
...oldState.img2img,
...newParams,
},
}));
},
setInpaint(newParams) {
set((oldState) => ({
...oldState,
inpaint: {
...oldState.inpaint,
...newParams,
},
}));
},
resetTxt2Img() {
set((oldState) => ({
...oldState,
txt2img: {
...defaults,
height: params.height.default,
width: params.width.default,
},
}));
},
resetImg2Img() {
set((oldState) => ({
...oldState,
img2img: {
...defaults,
strength: params.strength.default,
},
}));
},
resetInpaint() {
set((oldState) => ({
...oldState,
inpaint: {
...defaults,
},
}));
},
setLimit(limit) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
limit,
},
}));
},
setLoading(loading) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
loading,
},
}));
},
pushHistory(newImage: ApiResponse) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: [
newImage,
...oldState.history.images,
].slice(0, oldState.history.limit),
},
}));
},
setHistory(newHistory: Array<ApiResponse>) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: newHistory,
},
}));
},
}), { }), {
name: 'onnx-web', name: 'onnx-web',
partialize: (oldState) => ({ partialize: (oldState) => ({
@ -178,6 +37,7 @@ export async function main() {
}, },
}), }),
storage: createJSONStorage(() => localStorage), storage: createJSONStorage(() => localStorage),
version: 1,
})); }));
const query = new QueryClient(); const query = new QueryClient();

View File

@ -1,45 +1,170 @@
import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, Txt2ImgParams } from './api/client.js'; import { createStore, StateCreator } from 'zustand';
import { ConfigState } from './config.js'; import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, paramsFromConfig, Txt2ImgParams } from './api/client.js';
import { ConfigParams, ConfigState } from './config.js';
interface TabState<TabParams extends BaseImgParams> { type TabState<TabParams extends BaseImgParams> = ConfigState<Required<TabParams>>;
params: ConfigState<Required<TabParams>>;
reset(): void; interface Txt2ImgSlice {
update(params: Partial<ConfigState<Required<TabParams>>>): void; txt2img: TabState<Txt2ImgParams>;
setTxt2Img(params: Partial<Txt2ImgParams>): void;
resetTxt2Img(): void;
} }
interface OnnxState { interface Img2ImgSlice {
defaults: { img2img: TabState<Img2ImgParams>;
params: Required<BaseImgParams>;
update(newParams: Partial<BaseImgParams>): void;
};
txt2img: {
params: ConfigState<Required<Txt2ImgParams>>;
reset(): void; setImg2Img(params: Partial<Img2ImgParams>): void;
update(newParams: Partial<ConfigState<Required<Txt2ImgParams>>>): void; resetImg2Img(): void;
};
img2img: {
params: ConfigState<Required<Img2ImgParams>>;
reset(): void;
update(newParams: Partial<ConfigState<Required<Img2ImgParams>>>): void;
};
inpaint: {
params: ConfigState<Required<InpaintParams>>;
reset(): void;
update(newParams: Partial<ConfigState<Required<InpaintParams>>>): void;
};
history: {
images: Array<ApiResponse>;
limit: number;
loading: boolean;
setLimit(limit: number): void;
setLoading(loading: boolean): void;
setHistory(newHistory: Array<ApiResponse>): void;
pushHistory(newImage: ApiResponse): void;
};
} }
interface InpaintSlice {
inpaint: TabState<InpaintParams>;
setInpaint(params: Partial<InpaintParams>): void;
resetInpaint(): void;
}
interface HistorySlice {
history: Array<ApiResponse>;
limit: number;
loading: boolean;
pushHistory(image: ApiResponse): void;
removeHistory(image: ApiResponse): void;
setLimit(limit: number): void;
setLoading(loading: boolean): void;
}
interface DefaultSlice {
defaults: TabState<BaseImgParams>;
setDefaults(param: Partial<BaseImgParams>): void;
}
export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice;
export function createStateSlices(base: ConfigParams) {
const defaults = paramsFromConfig(base);
const createTxt2ImgSlice: StateCreator<OnnxState, [], [], Txt2ImgSlice> = (set) => ({
txt2img: {
...defaults,
width: base.width.default,
height: base.height.default,
},
setTxt2Img(params) {
set((prev) => ({
txt2img: {
...prev.txt2img,
...params,
},
}));
},
resetTxt2Img() {
set({
txt2img: {
...defaults,
width: base.width.default,
height: base.height.default,
},
});
},
});
const createImg2ImgSlice: StateCreator<OnnxState, [], [], Img2ImgSlice> = (set) => ({
img2img: {
...defaults,
strength: base.strength.default,
},
setImg2Img(params) {
set((prev) => ({
img2img: {
...prev.img2img,
...params,
},
}));
},
resetImg2Img() {
set({
img2img: {
...defaults,
strength: base.strength.default,
},
});
},
});
const createInpaintSlice: StateCreator<OnnxState, [], [], InpaintSlice> = (set) => ({
inpaint: {
...defaults,
},
setInpaint(params) {
set((prev) => ({
inpaint: {
...prev.inpaint,
...params,
},
}));
},
resetInpaint() {
set({
inpaint: {
...defaults,
},
});
},
});
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
history: [],
limit: 4,
loading: false,
pushHistory(image) {
set((prev) => ({
...prev,
history: [
image,
...prev.history,
],
}));
},
removeHistory(image) {
// ?
},
setLimit(limit) {
set((prev) => ({
...prev,
limit,
}));
},
setLoading(loading) {
set((prev) => ({
...prev,
loading,
}));
},
});
const createDefaultSlice: StateCreator<OnnxState, [], [], DefaultSlice> = (set) => ({
defaults: {
...defaults,
},
setDefaults(params) {
set((prev) => ({
defaults: {
...prev.defaults,
...params,
}
}));
},
});
return {
createDefaultSlice,
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createTxt2ImgSlice,
};
}

View File

@ -42,7 +42,8 @@
"spinalcase", "spinalcase",
"stringcase", "stringcase",
"venv", "venv",
"virtualenv" "virtualenv",
"zustand"
] ]
} }
} }