1
0
Fork 0

fix(gui): break up state into slices for each tab

This commit is contained in:
Sean Sube 2023-01-12 17:49:51 -06:00
parent 35e2e1dda6
commit 689a6a183f
5 changed files with 185 additions and 204 deletions

View File

@ -4,37 +4,32 @@ import { useContext } from 'react';
import * as React from 'react';
import { useStore } from 'zustand';
import { ApiResponse } from '../api/client.js';
import { StateContext } from '../main.js';
import { ImageCard } from './ImageCard.js';
import { LoadingCard } from './LoadingCard.js';
export function ImageHistory() {
const history = useStore(mustExist(useContext(StateContext)), (state) => state.history);
const limit = useStore(mustExist(useContext(StateContext)), (state) => state.limit);
const loading = useStore(mustExist(useContext(StateContext)), (state) => state.loading);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setHistory = useStore(mustExist(useContext(StateContext)), (state) => state.setHistory);
const { images } = history;
const removeHistory = useStore(mustExist(useContext(StateContext)), (state) => state.removeHistory);
const children = [];
if (history.loading) {
if (loading) {
children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config
}
function removeHistory(image: ApiResponse) {
setHistory(images.filter((item) => image.output !== item.output));
}
if (images.length > 0) {
children.push(...images.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
if (history.length > 0) {
children.push(...history.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
} else {
if (history.loading === false) {
if (loading === false) {
children.push(<div>No results. Press Generate.</div>);
}
}
const limited = children.slice(0, history.limit);
const limited = children.slice(0, limit);
return <Grid container spacing={2}>{limited.map((child, idx) => <Grid item key={idx} xs={6}>{child}</Grid>)}</Grid>;
}

View File

@ -22,7 +22,7 @@ export function Settings(_props: SettingsProps) {
min={2}
max={20}
step={1}
value={state.history.limit}
value={state.limit}
onChange={(value) => state.setLimit(value)}
/>
<TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => {

View File

@ -7,167 +7,26 @@ import { QueryClient, QueryClientProvider } from 'react-query';
import { createStore, StoreApi } from 'zustand';
import { createJSONStorage, persist } from 'zustand/middleware';
import { ApiClient, ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js';
import { ApiClient, makeClient } from './api/client.js';
import { OnnxWeb } from './components/OnnxWeb.js';
import { ConfigState, loadConfig } from './config.js';
import { loadConfig } from './config.js';
import { createStateSlices, OnnxState } from './state.js';
const { createContext } = React;
interface OnnxState {
defaults: Required<BaseImgParams>;
txt2img: ConfigState<Required<Txt2ImgParams>>;
img2img: ConfigState<Required<Img2ImgParams>>;
inpaint: ConfigState<Required<InpaintParams>>;
setDefaults(newParams: Partial<BaseImgParams>): void;
setTxt2Img(newParams: Partial<ConfigState<Required<Txt2ImgParams>>>): void;
setImg2Img(newParams: Partial<ConfigState<Required<Img2ImgParams>>>): void;
setInpaint(newParams: Partial<ConfigState<Required<InpaintParams>>>): void;
resetTxt2Img(): void;
resetImg2Img(): void;
resetInpaint(): void;
history: {
images: Array<ApiResponse>;
limit: number;
loading: boolean;
};
setLimit(limit: number): void;
setLoading(loading: boolean): void;
setHistory(newHistory: Array<ApiResponse>): void;
pushHistory(newImage: ApiResponse): void;
}
export async function main() {
const config = await loadConfig();
const client = makeClient(config.api.root);
const params = await client.params();
merge(params, config.params);
const defaults = paramsFromConfig(params);
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((set) => ({
defaults,
history: {
images: [],
limit: 4,
loading: false,
},
txt2img: {
...defaults,
height: params.height.default,
width: params.width.default,
},
img2img: {
...defaults,
strength: params.strength.default,
},
inpaint: {
...defaults,
},
setDefaults(newParams) {
set((oldState) => ({
...oldState,
defaults: {
...oldState.defaults,
...newParams,
},
}));
},
setTxt2Img(newParams) {
set((oldState) => ({
...oldState,
txt2img: {
...oldState.txt2img,
...newParams,
},
}));
},
setImg2Img(newParams) {
set((oldState) => ({
...oldState,
img2img: {
...oldState.img2img,
...newParams,
},
}));
},
setInpaint(newParams) {
set((oldState) => ({
...oldState,
inpaint: {
...oldState.inpaint,
...newParams,
},
}));
},
resetTxt2Img() {
set((oldState) => ({
...oldState,
txt2img: {
...defaults,
height: params.height.default,
width: params.width.default,
},
}));
},
resetImg2Img() {
set((oldState) => ({
...oldState,
img2img: {
...defaults,
strength: params.strength.default,
},
}));
},
resetInpaint() {
set((oldState) => ({
...oldState,
inpaint: {
...defaults,
},
}));
},
setLimit(limit) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
limit,
},
}));
},
setLoading(loading) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
loading,
},
}));
},
pushHistory(newImage: ApiResponse) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: [
newImage,
...oldState.history.images,
].slice(0, oldState.history.limit),
},
}));
},
setHistory(newHistory: Array<ApiResponse>) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: newHistory,
},
}));
},
const { createDefaultSlice, createHistorySlice, createImg2ImgSlice, createInpaintSlice, createTxt2ImgSlice } = createStateSlices(params);
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((...slice) => ({
...createTxt2ImgSlice(...slice),
...createImg2ImgSlice(...slice),
...createInpaintSlice(...slice),
...createHistorySlice(...slice),
...createDefaultSlice(...slice),
}), {
name: 'onnx-web',
partialize: (oldState) => ({
@ -178,6 +37,7 @@ export async function main() {
},
}),
storage: createJSONStorage(() => localStorage),
version: 1,
}));
const query = new QueryClient();

View File

@ -1,45 +1,170 @@
import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, Txt2ImgParams } from './api/client.js';
import { ConfigState } from './config.js';
import { createStore, StateCreator } from 'zustand';
import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, paramsFromConfig, Txt2ImgParams } from './api/client.js';
import { ConfigParams, ConfigState } from './config.js';
interface TabState<TabParams extends BaseImgParams> {
params: ConfigState<Required<TabParams>>;
type TabState<TabParams extends BaseImgParams> = ConfigState<Required<TabParams>>;
reset(): void;
update(params: Partial<ConfigState<Required<TabParams>>>): void;
interface Txt2ImgSlice {
txt2img: TabState<Txt2ImgParams>;
setTxt2Img(params: Partial<Txt2ImgParams>): void;
resetTxt2Img(): void;
}
interface OnnxState {
defaults: {
params: Required<BaseImgParams>;
update(newParams: Partial<BaseImgParams>): void;
};
txt2img: {
params: ConfigState<Required<Txt2ImgParams>>;
interface Img2ImgSlice {
img2img: TabState<Img2ImgParams>;
reset(): void;
update(newParams: Partial<ConfigState<Required<Txt2ImgParams>>>): void;
};
img2img: {
params: ConfigState<Required<Img2ImgParams>>;
reset(): void;
update(newParams: Partial<ConfigState<Required<Img2ImgParams>>>): void;
};
inpaint: {
params: ConfigState<Required<InpaintParams>>;
reset(): void;
update(newParams: Partial<ConfigState<Required<InpaintParams>>>): void;
};
history: {
images: Array<ApiResponse>;
limit: number;
loading: boolean;
setLimit(limit: number): void;
setLoading(loading: boolean): void;
setHistory(newHistory: Array<ApiResponse>): void;
pushHistory(newImage: ApiResponse): void;
};
setImg2Img(params: Partial<Img2ImgParams>): void;
resetImg2Img(): void;
}
interface InpaintSlice {
inpaint: TabState<InpaintParams>;
setInpaint(params: Partial<InpaintParams>): void;
resetInpaint(): void;
}
interface HistorySlice {
history: Array<ApiResponse>;
limit: number;
loading: boolean;
pushHistory(image: ApiResponse): void;
removeHistory(image: ApiResponse): void;
setLimit(limit: number): void;
setLoading(loading: boolean): void;
}
interface DefaultSlice {
defaults: TabState<BaseImgParams>;
setDefaults(param: Partial<BaseImgParams>): void;
}
export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice;
export function createStateSlices(base: ConfigParams) {
const defaults = paramsFromConfig(base);
const createTxt2ImgSlice: StateCreator<OnnxState, [], [], Txt2ImgSlice> = (set) => ({
txt2img: {
...defaults,
width: base.width.default,
height: base.height.default,
},
setTxt2Img(params) {
set((prev) => ({
txt2img: {
...prev.txt2img,
...params,
},
}));
},
resetTxt2Img() {
set({
txt2img: {
...defaults,
width: base.width.default,
height: base.height.default,
},
});
},
});
const createImg2ImgSlice: StateCreator<OnnxState, [], [], Img2ImgSlice> = (set) => ({
img2img: {
...defaults,
strength: base.strength.default,
},
setImg2Img(params) {
set((prev) => ({
img2img: {
...prev.img2img,
...params,
},
}));
},
resetImg2Img() {
set({
img2img: {
...defaults,
strength: base.strength.default,
},
});
},
});
const createInpaintSlice: StateCreator<OnnxState, [], [], InpaintSlice> = (set) => ({
inpaint: {
...defaults,
},
setInpaint(params) {
set((prev) => ({
inpaint: {
...prev.inpaint,
...params,
},
}));
},
resetInpaint() {
set({
inpaint: {
...defaults,
},
});
},
});
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
history: [],
limit: 4,
loading: false,
pushHistory(image) {
set((prev) => ({
...prev,
history: [
image,
...prev.history,
],
}));
},
removeHistory(image) {
// ?
},
setLimit(limit) {
set((prev) => ({
...prev,
limit,
}));
},
setLoading(loading) {
set((prev) => ({
...prev,
loading,
}));
},
});
const createDefaultSlice: StateCreator<OnnxState, [], [], DefaultSlice> = (set) => ({
defaults: {
...defaults,
},
setDefaults(params) {
set((prev) => ({
defaults: {
...prev.defaults,
...params,
}
}));
},
});
return {
createDefaultSlice,
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createTxt2ImgSlice,
};
}

View File

@ -42,7 +42,8 @@
"spinalcase",
"stringcase",
"venv",
"virtualenv"
"virtualenv",
"zustand"
]
}
}