diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx
index 443dfebd..cef39158 100644
--- a/gui/src/components/ImageHistory.tsx
+++ b/gui/src/components/ImageHistory.tsx
@@ -4,37 +4,32 @@ import { useContext } from 'react';
import * as React from 'react';
import { useStore } from 'zustand';
-import { ApiResponse } from '../api/client.js';
import { StateContext } from '../main.js';
import { ImageCard } from './ImageCard.js';
import { LoadingCard } from './LoadingCard.js';
export function ImageHistory() {
const history = useStore(mustExist(useContext(StateContext)), (state) => state.history);
+ const limit = useStore(mustExist(useContext(StateContext)), (state) => state.limit);
+ const loading = useStore(mustExist(useContext(StateContext)), (state) => state.loading);
// eslint-disable-next-line @typescript-eslint/unbound-method
- const setHistory = useStore(mustExist(useContext(StateContext)), (state) => state.setHistory);
-
- const { images } = history;
+ const removeHistory = useStore(mustExist(useContext(StateContext)), (state) => state.removeHistory);
const children = [];
- if (history.loading) {
+ if (loading) {
children.push(); // TODO: get dimensions from config
}
- function removeHistory(image: ApiResponse) {
- setHistory(images.filter((item) => image.output !== item.output));
- }
-
- if (images.length > 0) {
- children.push(...images.map((item) => ));
+ if (history.length > 0) {
+ children.push(...history.map((item) => ));
} else {
- if (history.loading === false) {
+ if (loading === false) {
children.push(
No results. Press Generate.
);
}
}
- const limited = children.slice(0, history.limit);
+ const limited = children.slice(0, limit);
return {limited.map((child, idx) => {child})};
}
diff --git a/gui/src/components/Settings.tsx b/gui/src/components/Settings.tsx
index 3e69a14d..7b3388a6 100644
--- a/gui/src/components/Settings.tsx
+++ b/gui/src/components/Settings.tsx
@@ -22,7 +22,7 @@ export function Settings(_props: SettingsProps) {
min={2}
max={20}
step={1}
- value={state.history.limit}
+ value={state.limit}
onChange={(value) => state.setLimit(value)}
/>
{
diff --git a/gui/src/main.tsx b/gui/src/main.tsx
index e3bf5118..1c8ed814 100644
--- a/gui/src/main.tsx
+++ b/gui/src/main.tsx
@@ -7,167 +7,26 @@ import { QueryClient, QueryClientProvider } from 'react-query';
import { createStore, StoreApi } from 'zustand';
import { createJSONStorage, persist } from 'zustand/middleware';
-import { ApiClient, ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js';
+import { ApiClient, makeClient } from './api/client.js';
import { OnnxWeb } from './components/OnnxWeb.js';
-import { ConfigState, loadConfig } from './config.js';
+import { loadConfig } from './config.js';
+import { createStateSlices, OnnxState } from './state.js';
const { createContext } = React;
-interface OnnxState {
- defaults: Required;
- txt2img: ConfigState>;
- img2img: ConfigState>;
- inpaint: ConfigState>;
-
- setDefaults(newParams: Partial): void;
- setTxt2Img(newParams: Partial>>): void;
- setImg2Img(newParams: Partial>>): void;
- setInpaint(newParams: Partial>>): void;
-
- resetTxt2Img(): void;
- resetImg2Img(): void;
- resetInpaint(): void;
-
- history: {
- images: Array;
- limit: number;
- loading: boolean;
- };
-
- setLimit(limit: number): void;
- setLoading(loading: boolean): void;
- setHistory(newHistory: Array): void;
- pushHistory(newImage: ApiResponse): void;
-}
-
export async function main() {
const config = await loadConfig();
const client = makeClient(config.api.root);
const params = await client.params();
merge(params, config.params);
- const defaults = paramsFromConfig(params);
- const state = createStore(persist((set) => ({
- defaults,
- history: {
- images: [],
- limit: 4,
- loading: false,
- },
- txt2img: {
- ...defaults,
- height: params.height.default,
- width: params.width.default,
- },
- img2img: {
- ...defaults,
- strength: params.strength.default,
- },
- inpaint: {
- ...defaults,
- },
- setDefaults(newParams) {
- set((oldState) => ({
- ...oldState,
- defaults: {
- ...oldState.defaults,
- ...newParams,
- },
- }));
- },
- setTxt2Img(newParams) {
- set((oldState) => ({
- ...oldState,
- txt2img: {
- ...oldState.txt2img,
- ...newParams,
- },
- }));
- },
- setImg2Img(newParams) {
- set((oldState) => ({
- ...oldState,
- img2img: {
- ...oldState.img2img,
- ...newParams,
- },
- }));
- },
- setInpaint(newParams) {
- set((oldState) => ({
- ...oldState,
- inpaint: {
- ...oldState.inpaint,
- ...newParams,
- },
- }));
- },
- resetTxt2Img() {
- set((oldState) => ({
- ...oldState,
- txt2img: {
- ...defaults,
- height: params.height.default,
- width: params.width.default,
- },
- }));
- },
- resetImg2Img() {
- set((oldState) => ({
- ...oldState,
- img2img: {
- ...defaults,
- strength: params.strength.default,
- },
- }));
- },
- resetInpaint() {
- set((oldState) => ({
- ...oldState,
- inpaint: {
- ...defaults,
- },
- }));
- },
- setLimit(limit) {
- set((oldState) => ({
- ...oldState,
- history: {
- ...oldState.history,
- limit,
- },
- }));
- },
- setLoading(loading) {
- set((oldState) => ({
- ...oldState,
- history: {
- ...oldState.history,
- loading,
- },
- }));
- },
- pushHistory(newImage: ApiResponse) {
- set((oldState) => ({
- ...oldState,
- history: {
- ...oldState.history,
- images: [
- newImage,
- ...oldState.history.images,
- ].slice(0, oldState.history.limit),
- },
- }));
- },
- setHistory(newHistory: Array) {
- set((oldState) => ({
- ...oldState,
- history: {
- ...oldState.history,
- images: newHistory,
- },
- }));
- },
+ const { createDefaultSlice, createHistorySlice, createImg2ImgSlice, createInpaintSlice, createTxt2ImgSlice } = createStateSlices(params);
+ const state = createStore(persist((...slice) => ({
+ ...createTxt2ImgSlice(...slice),
+ ...createImg2ImgSlice(...slice),
+ ...createInpaintSlice(...slice),
+ ...createHistorySlice(...slice),
+ ...createDefaultSlice(...slice),
}), {
name: 'onnx-web',
partialize: (oldState) => ({
@@ -178,6 +37,7 @@ export async function main() {
},
}),
storage: createJSONStorage(() => localStorage),
+ version: 1,
}));
const query = new QueryClient();
diff --git a/gui/src/state.ts b/gui/src/state.ts
index 7ef2fe4f..69abdd2d 100644
--- a/gui/src/state.ts
+++ b/gui/src/state.ts
@@ -1,45 +1,170 @@
-import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, Txt2ImgParams } from './api/client.js';
-import { ConfigState } from './config.js';
+import { createStore, StateCreator } from 'zustand';
+import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, paramsFromConfig, Txt2ImgParams } from './api/client.js';
+import { ConfigParams, ConfigState } from './config.js';
-interface TabState {
- params: ConfigState>;
+type TabState = ConfigState>;
- reset(): void;
- update(params: Partial>>): void;
+interface Txt2ImgSlice {
+ txt2img: TabState;
+
+ setTxt2Img(params: Partial): void;
+ resetTxt2Img(): void;
}
-interface OnnxState {
- defaults: {
- params: Required;
- update(newParams: Partial): void;
- };
- txt2img: {
- params: ConfigState>;
+interface Img2ImgSlice {
+ img2img: TabState;
- reset(): void;
- update(newParams: Partial>>): void;
- };
- img2img: {
- params: ConfigState>;
-
- reset(): void;
- update(newParams: Partial>>): void;
- };
- inpaint: {
- params: ConfigState>;
-
- reset(): void;
- update(newParams: Partial>>): void;
- };
- history: {
- images: Array;
- limit: number;
- loading: boolean;
-
- setLimit(limit: number): void;
- setLoading(loading: boolean): void;
- setHistory(newHistory: Array): void;
- pushHistory(newImage: ApiResponse): void;
- };
+ setImg2Img(params: Partial): void;
+ resetImg2Img(): void;
}
+interface InpaintSlice {
+ inpaint: TabState;
+
+ setInpaint(params: Partial): void;
+ resetInpaint(): void;
+}
+
+interface HistorySlice {
+ history: Array;
+ limit: number;
+ loading: boolean;
+
+ pushHistory(image: ApiResponse): void;
+ removeHistory(image: ApiResponse): void;
+ setLimit(limit: number): void;
+ setLoading(loading: boolean): void;
+}
+
+interface DefaultSlice {
+ defaults: TabState;
+
+ setDefaults(param: Partial): void;
+}
+
+export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice;
+
+export function createStateSlices(base: ConfigParams) {
+ const defaults = paramsFromConfig(base);
+
+ const createTxt2ImgSlice: StateCreator = (set) => ({
+ txt2img: {
+ ...defaults,
+ width: base.width.default,
+ height: base.height.default,
+ },
+ setTxt2Img(params) {
+ set((prev) => ({
+ txt2img: {
+ ...prev.txt2img,
+ ...params,
+ },
+ }));
+ },
+ resetTxt2Img() {
+ set({
+ txt2img: {
+ ...defaults,
+ width: base.width.default,
+ height: base.height.default,
+ },
+ });
+ },
+ });
+
+ const createImg2ImgSlice: StateCreator = (set) => ({
+ img2img: {
+ ...defaults,
+ strength: base.strength.default,
+ },
+ setImg2Img(params) {
+ set((prev) => ({
+ img2img: {
+ ...prev.img2img,
+ ...params,
+ },
+ }));
+ },
+ resetImg2Img() {
+ set({
+ img2img: {
+ ...defaults,
+ strength: base.strength.default,
+ },
+ });
+ },
+ });
+
+ const createInpaintSlice: StateCreator = (set) => ({
+ inpaint: {
+ ...defaults,
+ },
+ setInpaint(params) {
+ set((prev) => ({
+ inpaint: {
+ ...prev.inpaint,
+ ...params,
+ },
+ }));
+ },
+ resetInpaint() {
+ set({
+ inpaint: {
+ ...defaults,
+ },
+ });
+ },
+ });
+
+ const createHistorySlice: StateCreator = (set) => ({
+ history: [],
+ limit: 4,
+ loading: false,
+ pushHistory(image) {
+ set((prev) => ({
+ ...prev,
+ history: [
+ image,
+ ...prev.history,
+ ],
+ }));
+ },
+ removeHistory(image) {
+ // ?
+ },
+ setLimit(limit) {
+ set((prev) => ({
+ ...prev,
+ limit,
+ }));
+ },
+ setLoading(loading) {
+ set((prev) => ({
+ ...prev,
+ loading,
+ }));
+ },
+ });
+
+ const createDefaultSlice: StateCreator = (set) => ({
+ defaults: {
+ ...defaults,
+ },
+ setDefaults(params) {
+ set((prev) => ({
+ defaults: {
+ ...prev.defaults,
+ ...params,
+ }
+ }));
+ },
+ });
+
+ return {
+ createDefaultSlice,
+ createHistorySlice,
+ createImg2ImgSlice,
+ createInpaintSlice,
+ createTxt2ImgSlice,
+ };
+}
diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace
index 1f5a39af..8daae773 100644
--- a/onnx-web.code-workspace
+++ b/onnx-web.code-workspace
@@ -42,7 +42,8 @@
"spinalcase",
"stringcase",
"venv",
- "virtualenv"
+ "virtualenv",
+ "zustand"
]
}
}
\ No newline at end of file