fix(gui): break up state into slices for each tab
This commit is contained in:
parent
35e2e1dda6
commit
689a6a183f
|
@ -4,37 +4,32 @@ import { useContext } from 'react';
|
|||
import * as React from 'react';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ApiResponse } from '../api/client.js';
|
||||
import { StateContext } from '../main.js';
|
||||
import { ImageCard } from './ImageCard.js';
|
||||
import { LoadingCard } from './LoadingCard.js';
|
||||
|
||||
export function ImageHistory() {
|
||||
const history = useStore(mustExist(useContext(StateContext)), (state) => state.history);
|
||||
const limit = useStore(mustExist(useContext(StateContext)), (state) => state.limit);
|
||||
const loading = useStore(mustExist(useContext(StateContext)), (state) => state.loading);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setHistory = useStore(mustExist(useContext(StateContext)), (state) => state.setHistory);
|
||||
|
||||
const { images } = history;
|
||||
const removeHistory = useStore(mustExist(useContext(StateContext)), (state) => state.removeHistory);
|
||||
|
||||
const children = [];
|
||||
|
||||
if (history.loading) {
|
||||
if (loading) {
|
||||
children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config
|
||||
}
|
||||
|
||||
function removeHistory(image: ApiResponse) {
|
||||
setHistory(images.filter((item) => image.output !== item.output));
|
||||
}
|
||||
|
||||
if (images.length > 0) {
|
||||
children.push(...images.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
|
||||
if (history.length > 0) {
|
||||
children.push(...history.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
|
||||
} else {
|
||||
if (history.loading === false) {
|
||||
if (loading === false) {
|
||||
children.push(<div>No results. Press Generate.</div>);
|
||||
}
|
||||
}
|
||||
|
||||
const limited = children.slice(0, history.limit);
|
||||
const limited = children.slice(0, limit);
|
||||
|
||||
return <Grid container spacing={2}>{limited.map((child, idx) => <Grid item key={idx} xs={6}>{child}</Grid>)}</Grid>;
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ export function Settings(_props: SettingsProps) {
|
|||
min={2}
|
||||
max={20}
|
||||
step={1}
|
||||
value={state.history.limit}
|
||||
value={state.limit}
|
||||
onChange={(value) => state.setLimit(value)}
|
||||
/>
|
||||
<TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => {
|
||||
|
|
162
gui/src/main.tsx
162
gui/src/main.tsx
|
@ -7,167 +7,26 @@ import { QueryClient, QueryClientProvider } from 'react-query';
|
|||
import { createStore, StoreApi } from 'zustand';
|
||||
import { createJSONStorage, persist } from 'zustand/middleware';
|
||||
|
||||
import { ApiClient, ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js';
|
||||
import { ApiClient, makeClient } from './api/client.js';
|
||||
import { OnnxWeb } from './components/OnnxWeb.js';
|
||||
import { ConfigState, loadConfig } from './config.js';
|
||||
import { loadConfig } from './config.js';
|
||||
import { createStateSlices, OnnxState } from './state.js';
|
||||
|
||||
const { createContext } = React;
|
||||
|
||||
interface OnnxState {
|
||||
defaults: Required<BaseImgParams>;
|
||||
txt2img: ConfigState<Required<Txt2ImgParams>>;
|
||||
img2img: ConfigState<Required<Img2ImgParams>>;
|
||||
inpaint: ConfigState<Required<InpaintParams>>;
|
||||
|
||||
setDefaults(newParams: Partial<BaseImgParams>): void;
|
||||
setTxt2Img(newParams: Partial<ConfigState<Required<Txt2ImgParams>>>): void;
|
||||
setImg2Img(newParams: Partial<ConfigState<Required<Img2ImgParams>>>): void;
|
||||
setInpaint(newParams: Partial<ConfigState<Required<InpaintParams>>>): void;
|
||||
|
||||
resetTxt2Img(): void;
|
||||
resetImg2Img(): void;
|
||||
resetInpaint(): void;
|
||||
|
||||
history: {
|
||||
images: Array<ApiResponse>;
|
||||
limit: number;
|
||||
loading: boolean;
|
||||
};
|
||||
|
||||
setLimit(limit: number): void;
|
||||
setLoading(loading: boolean): void;
|
||||
setHistory(newHistory: Array<ApiResponse>): void;
|
||||
pushHistory(newImage: ApiResponse): void;
|
||||
}
|
||||
|
||||
export async function main() {
|
||||
const config = await loadConfig();
|
||||
const client = makeClient(config.api.root);
|
||||
const params = await client.params();
|
||||
merge(params, config.params);
|
||||
|
||||
const defaults = paramsFromConfig(params);
|
||||
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((set) => ({
|
||||
defaults,
|
||||
history: {
|
||||
images: [],
|
||||
limit: 4,
|
||||
loading: false,
|
||||
},
|
||||
txt2img: {
|
||||
...defaults,
|
||||
height: params.height.default,
|
||||
width: params.width.default,
|
||||
},
|
||||
img2img: {
|
||||
...defaults,
|
||||
strength: params.strength.default,
|
||||
},
|
||||
inpaint: {
|
||||
...defaults,
|
||||
},
|
||||
setDefaults(newParams) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
defaults: {
|
||||
...oldState.defaults,
|
||||
...newParams,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setTxt2Img(newParams) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
txt2img: {
|
||||
...oldState.txt2img,
|
||||
...newParams,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setImg2Img(newParams) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
img2img: {
|
||||
...oldState.img2img,
|
||||
...newParams,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setInpaint(newParams) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
inpaint: {
|
||||
...oldState.inpaint,
|
||||
...newParams,
|
||||
},
|
||||
}));
|
||||
},
|
||||
resetTxt2Img() {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
txt2img: {
|
||||
...defaults,
|
||||
height: params.height.default,
|
||||
width: params.width.default,
|
||||
},
|
||||
}));
|
||||
},
|
||||
resetImg2Img() {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
img2img: {
|
||||
...defaults,
|
||||
strength: params.strength.default,
|
||||
},
|
||||
}));
|
||||
},
|
||||
resetInpaint() {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
inpaint: {
|
||||
...defaults,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setLimit(limit) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
history: {
|
||||
...oldState.history,
|
||||
limit,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setLoading(loading) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
history: {
|
||||
...oldState.history,
|
||||
loading,
|
||||
},
|
||||
}));
|
||||
},
|
||||
pushHistory(newImage: ApiResponse) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
history: {
|
||||
...oldState.history,
|
||||
images: [
|
||||
newImage,
|
||||
...oldState.history.images,
|
||||
].slice(0, oldState.history.limit),
|
||||
},
|
||||
}));
|
||||
},
|
||||
setHistory(newHistory: Array<ApiResponse>) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
history: {
|
||||
...oldState.history,
|
||||
images: newHistory,
|
||||
},
|
||||
}));
|
||||
},
|
||||
const { createDefaultSlice, createHistorySlice, createImg2ImgSlice, createInpaintSlice, createTxt2ImgSlice } = createStateSlices(params);
|
||||
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((...slice) => ({
|
||||
...createTxt2ImgSlice(...slice),
|
||||
...createImg2ImgSlice(...slice),
|
||||
...createInpaintSlice(...slice),
|
||||
...createHistorySlice(...slice),
|
||||
...createDefaultSlice(...slice),
|
||||
}), {
|
||||
name: 'onnx-web',
|
||||
partialize: (oldState) => ({
|
||||
|
@ -178,6 +37,7 @@ export async function main() {
|
|||
},
|
||||
}),
|
||||
storage: createJSONStorage(() => localStorage),
|
||||
version: 1,
|
||||
}));
|
||||
|
||||
const query = new QueryClient();
|
||||
|
|
187
gui/src/state.ts
187
gui/src/state.ts
|
@ -1,45 +1,170 @@
|
|||
import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, Txt2ImgParams } from './api/client.js';
|
||||
import { ConfigState } from './config.js';
|
||||
import { createStore, StateCreator } from 'zustand';
|
||||
import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, paramsFromConfig, Txt2ImgParams } from './api/client.js';
|
||||
import { ConfigParams, ConfigState } from './config.js';
|
||||
|
||||
interface TabState<TabParams extends BaseImgParams> {
|
||||
params: ConfigState<Required<TabParams>>;
|
||||
type TabState<TabParams extends BaseImgParams> = ConfigState<Required<TabParams>>;
|
||||
|
||||
reset(): void;
|
||||
update(params: Partial<ConfigState<Required<TabParams>>>): void;
|
||||
interface Txt2ImgSlice {
|
||||
txt2img: TabState<Txt2ImgParams>;
|
||||
|
||||
setTxt2Img(params: Partial<Txt2ImgParams>): void;
|
||||
resetTxt2Img(): void;
|
||||
}
|
||||
|
||||
interface OnnxState {
|
||||
defaults: {
|
||||
params: Required<BaseImgParams>;
|
||||
update(newParams: Partial<BaseImgParams>): void;
|
||||
};
|
||||
txt2img: {
|
||||
params: ConfigState<Required<Txt2ImgParams>>;
|
||||
interface Img2ImgSlice {
|
||||
img2img: TabState<Img2ImgParams>;
|
||||
|
||||
reset(): void;
|
||||
update(newParams: Partial<ConfigState<Required<Txt2ImgParams>>>): void;
|
||||
};
|
||||
img2img: {
|
||||
params: ConfigState<Required<Img2ImgParams>>;
|
||||
setImg2Img(params: Partial<Img2ImgParams>): void;
|
||||
resetImg2Img(): void;
|
||||
}
|
||||
|
||||
reset(): void;
|
||||
update(newParams: Partial<ConfigState<Required<Img2ImgParams>>>): void;
|
||||
};
|
||||
inpaint: {
|
||||
params: ConfigState<Required<InpaintParams>>;
|
||||
interface InpaintSlice {
|
||||
inpaint: TabState<InpaintParams>;
|
||||
|
||||
reset(): void;
|
||||
update(newParams: Partial<ConfigState<Required<InpaintParams>>>): void;
|
||||
};
|
||||
history: {
|
||||
images: Array<ApiResponse>;
|
||||
setInpaint(params: Partial<InpaintParams>): void;
|
||||
resetInpaint(): void;
|
||||
}
|
||||
|
||||
interface HistorySlice {
|
||||
history: Array<ApiResponse>;
|
||||
limit: number;
|
||||
loading: boolean;
|
||||
|
||||
pushHistory(image: ApiResponse): void;
|
||||
removeHistory(image: ApiResponse): void;
|
||||
setLimit(limit: number): void;
|
||||
setLoading(loading: boolean): void;
|
||||
setHistory(newHistory: Array<ApiResponse>): void;
|
||||
pushHistory(newImage: ApiResponse): void;
|
||||
};
|
||||
}
|
||||
|
||||
interface DefaultSlice {
|
||||
defaults: TabState<BaseImgParams>;
|
||||
|
||||
setDefaults(param: Partial<BaseImgParams>): void;
|
||||
}
|
||||
|
||||
export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice;
|
||||
|
||||
export function createStateSlices(base: ConfigParams) {
|
||||
const defaults = paramsFromConfig(base);
|
||||
|
||||
const createTxt2ImgSlice: StateCreator<OnnxState, [], [], Txt2ImgSlice> = (set) => ({
|
||||
txt2img: {
|
||||
...defaults,
|
||||
width: base.width.default,
|
||||
height: base.height.default,
|
||||
},
|
||||
setTxt2Img(params) {
|
||||
set((prev) => ({
|
||||
txt2img: {
|
||||
...prev.txt2img,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
resetTxt2Img() {
|
||||
set({
|
||||
txt2img: {
|
||||
...defaults,
|
||||
width: base.width.default,
|
||||
height: base.height.default,
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const createImg2ImgSlice: StateCreator<OnnxState, [], [], Img2ImgSlice> = (set) => ({
|
||||
img2img: {
|
||||
...defaults,
|
||||
strength: base.strength.default,
|
||||
},
|
||||
setImg2Img(params) {
|
||||
set((prev) => ({
|
||||
img2img: {
|
||||
...prev.img2img,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
resetImg2Img() {
|
||||
set({
|
||||
img2img: {
|
||||
...defaults,
|
||||
strength: base.strength.default,
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const createInpaintSlice: StateCreator<OnnxState, [], [], InpaintSlice> = (set) => ({
|
||||
inpaint: {
|
||||
...defaults,
|
||||
},
|
||||
setInpaint(params) {
|
||||
set((prev) => ({
|
||||
inpaint: {
|
||||
...prev.inpaint,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
resetInpaint() {
|
||||
set({
|
||||
inpaint: {
|
||||
...defaults,
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
|
||||
history: [],
|
||||
limit: 4,
|
||||
loading: false,
|
||||
pushHistory(image) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
history: [
|
||||
image,
|
||||
...prev.history,
|
||||
],
|
||||
}));
|
||||
},
|
||||
removeHistory(image) {
|
||||
// ?
|
||||
},
|
||||
setLimit(limit) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
limit,
|
||||
}));
|
||||
},
|
||||
setLoading(loading) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
loading,
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
const createDefaultSlice: StateCreator<OnnxState, [], [], DefaultSlice> = (set) => ({
|
||||
defaults: {
|
||||
...defaults,
|
||||
},
|
||||
setDefaults(params) {
|
||||
set((prev) => ({
|
||||
defaults: {
|
||||
...prev.defaults,
|
||||
...params,
|
||||
}
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
createDefaultSlice,
|
||||
createHistorySlice,
|
||||
createImg2ImgSlice,
|
||||
createInpaintSlice,
|
||||
createTxt2ImgSlice,
|
||||
};
|
||||
}
|
||||
|
|
|
@ -42,7 +42,8 @@
|
|||
"spinalcase",
|
||||
"stringcase",
|
||||
"venv",
|
||||
"virtualenv"
|
||||
"virtualenv",
|
||||
"zustand"
|
||||
]
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue