From 662bf42454c31df6f439d440f0f2cfe4d59397da Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 10 Jan 2023 22:35:55 -0600 Subject: [PATCH] feat(gui): share image history between tabs, add setting to adjust length of history (fixes #22) --- gui/src/components/ImageHistory.tsx | 37 ++++++++++++++++ gui/src/components/Img2Img.tsx | 13 +++--- gui/src/components/Inpaint.tsx | 18 ++++---- gui/src/components/MutationHistory.tsx | 58 -------------------------- gui/src/components/OnnxWeb.tsx | 9 +++- gui/src/components/Settings.tsx | 21 +++++++--- gui/src/components/Txt2Img.tsx | 18 ++++---- gui/src/main.tsx | 57 ++++++++++++++++++++++++- 8 files changed, 137 insertions(+), 94 deletions(-) create mode 100644 gui/src/components/ImageHistory.tsx delete mode 100644 gui/src/components/MutationHistory.tsx diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx new file mode 100644 index 00000000..2d135990 --- /dev/null +++ b/gui/src/components/ImageHistory.tsx @@ -0,0 +1,37 @@ +import { mustExist } from '@apextoaster/js-utils'; +import { Grid } from '@mui/material'; +import { useContext } from 'react'; +import * as React from 'react'; +import { useStore } from 'zustand'; + +import { ApiResponse } from '../api/client.js'; +import { StateContext } from '../main.js'; +import { ImageCard } from './ImageCard.js'; +import { LoadingCard } from './LoadingCard.js'; + +export function ImageHistory() { + const state = useStore(mustExist(useContext(StateContext))); + const { images } = state.history; + + const children = []; + + if (state.history.loading) { + children.push(); // TODO: get dimensions from config + } + + function removeHistory(image: ApiResponse) { + state.setHistory(images.filter((item) => image.output !== item.output)); + } + + if (images.length > 0) { + children.push(...images.map((item) => )); + } else { + if (state.history.loading === false) { + children.push(
No results. Press Generate.
); + } + } + + const limited = children.slice(0, state.history.limit); + + return {limited.map((child, idx) => {child})}; +} diff --git a/gui/src/components/Img2Img.tsx b/gui/src/components/Img2Img.tsx index 31807a4d..d8fa2187 100644 --- a/gui/src/components/Img2Img.tsx +++ b/gui/src/components/Img2Img.tsx @@ -4,13 +4,10 @@ import * as React from 'react'; import { useMutation } from 'react-query'; import { useStore } from 'zustand'; -import { equalResponse } from '../api/client.js'; import { ConfigParams, IMAGE_FILTER } from '../config.js'; import { ClientContext, StateContext } from '../main.js'; -import { ImageCard } from './ImageCard.js'; import { ImageControl } from './ImageControl.js'; import { ImageInput } from './ImageInput.js'; -import { MutationHistory } from './MutationHistory.js'; import { NumericField } from './NumericField.js'; const { useContext, useState } = React; @@ -26,12 +23,17 @@ export function Img2Img(props: Img2ImgProps) { const { config, model, platform } = props; async function uploadSource() { - return client.img2img({ + state.setLoading(true); + + const output = await client.img2img({ ...state.img2img, model, platform, source: mustExist(source), // TODO: show an error if this doesn't exist }); + + state.pushHistory(output); + state.setLoading(false); } const client = mustExist(useContext(ClientContext)); @@ -60,9 +62,6 @@ export function Img2Img(props: Img2ImgProps) { }} /> - ; } diff --git a/gui/src/components/Inpaint.tsx b/gui/src/components/Inpaint.tsx index b286eaca..8443e0c8 100644 --- a/gui/src/components/Inpaint.tsx +++ b/gui/src/components/Inpaint.tsx @@ -5,13 +5,11 @@ import * as React from 'react'; import { useMutation } from 'react-query'; import { useStore } from 'zustand'; -import { ApiResponse, equalResponse } from '../api/client.js'; +import { ApiResponse } from '../api/client.js'; import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js'; import { ClientContext, StateContext } from '../main.js'; -import { ImageCard } from './ImageCard.js'; import { ImageControl } from './ImageControl.js'; import { ImageInput } from './ImageInput.js'; -import { MutationHistory } from './MutationHistory.js'; import { NumericField } from './NumericField.js'; const { useContext, useEffect, useRef, useState } = React; @@ -72,15 +70,20 @@ export function Inpaint(props: InpaintProps) { async function uploadSource() { const canvas = mustExist(canvasRef.current); - return new Promise((res, _rej) => { + state.setLoading(true); + return new Promise((res, rej) => { canvas.toBlob((blob) => { - res(client.inpaint({ + client.inpaint({ ...state.inpaint, model, platform, mask: mustExist(blob), source: mustExist(source), - })); + }).then((output) => { + state.pushHistory(output); + state.setLoading(false); + res(); + }).catch((err) => rej(err)); }); }); } @@ -262,9 +265,6 @@ export function Inpaint(props: InpaintProps) { }} /> - ; } diff --git a/gui/src/components/MutationHistory.tsx b/gui/src/components/MutationHistory.tsx deleted file mode 100644 index 331b2ca2..00000000 --- a/gui/src/components/MutationHistory.tsx +++ /dev/null @@ -1,58 +0,0 @@ -import { Grid } from '@mui/material'; -import { useState } from 'react'; -import * as React from 'react'; -import { UseMutationResult } from 'react-query'; -import { LoadingCard } from './LoadingCard.js'; - -export interface MutationHistoryChildProps { - value: T; - - onDelete: (key: T) => void; -} - -export interface MutationHistoryProps { - element: React.ComponentType>; - limit: number; - result: UseMutationResult; - - isEqual: (a: T, b: T) => boolean; -} - -export function MutationHistory(props: MutationHistoryProps) { - const { limit, result } = props; - const { status } = result; - - const [history, setHistory] = useState>([]); - const children = []; - - if (status === 'loading') { - children.push(); // TODO: get dimensions from parent - } - - if (status === 'success') { - const { data } = result; - if (history.some((other) => props.isEqual(data, other))) { - // item already exists, skip it - } else { - setHistory([ - data, - ...history, - ].slice(0, limit)); - } - } - - function removeHistory(data: T) { - setHistory(history.filter((item) => props.isEqual(item, data) === false)); - } - - if (history.length > 0) { - children.push(...history.map((item) => )); - } else { - // only show the prompt when the button has not been pushed - if (status !== 'loading') { - children.push(
No results. Press Generate.
); - } - } - - return {children.slice(0, limit).map((child, idx) => {child})}; -} diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx index b1a27da9..eb970ad5 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -1,6 +1,6 @@ import { mustExist } from '@apextoaster/js-utils'; import { TabContext, TabList, TabPanel } from '@mui/lab'; -import { Box, Container, Stack, Tab, Typography } from '@mui/material'; +import { Box, Container, Divider, Stack, Tab, Typography } from '@mui/material'; import * as React from 'react'; import { useQuery } from 'react-query'; @@ -8,6 +8,7 @@ import { ApiClient } from '../api/client.js'; import { ConfigParams, STALE_TIME } from '../config.js'; import { ClientContext } from '../main.js'; import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js'; +import { ImageHistory } from './ImageHistory.js'; import { Img2Img } from './Img2Img.js'; import { Inpaint } from './Inpaint.js'; import { QueryList } from './QueryList.js'; @@ -44,7 +45,7 @@ export function OnnxWeb(props: OnnxWebProps) { ONNX Web - + + + + + ); diff --git a/gui/src/components/Settings.tsx b/gui/src/components/Settings.tsx index 2fc2aece..3e69a14d 100644 --- a/gui/src/components/Settings.tsx +++ b/gui/src/components/Settings.tsx @@ -5,6 +5,7 @@ import { useStore } from 'zustand'; import { ConfigParams } from '../config.js'; import { StateContext } from '../main.js'; +import { NumericField } from './NumericField.js'; const { useContext } = React; @@ -16,12 +17,14 @@ export function Settings(_props: SettingsProps) { const state = useStore(mustExist(useContext(StateContext))); return - - - - - - + state.setLimit(value)} + /> { state.setDefaults({ model: event.target.value, @@ -42,5 +45,11 @@ export function Settings(_props: SettingsProps) { scheduler: event.target.value, }); }} /> + + + + + + ; } diff --git a/gui/src/components/Txt2Img.tsx b/gui/src/components/Txt2Img.tsx index 33486dc5..b21bbc13 100644 --- a/gui/src/components/Txt2Img.tsx +++ b/gui/src/components/Txt2Img.tsx @@ -4,15 +4,12 @@ import * as React from 'react'; import { useMutation } from 'react-query'; import { useStore } from 'zustand'; -import { BaseImgParams, equalResponse, paramsFromConfig } from '../api/client.js'; import { ConfigParams } from '../config.js'; import { ClientContext, StateContext } from '../main.js'; -import { ImageCard } from './ImageCard.js'; import { ImageControl } from './ImageControl.js'; -import { MutationHistory } from './MutationHistory.js'; import { NumericField } from './NumericField.js'; -const { useContext, useState } = React; +const { useContext } = React; export interface Txt2ImgProps { config: ConfigParams; @@ -25,11 +22,16 @@ export function Txt2Img(props: Txt2ImgProps) { const { config, model, platform } = props; async function generateImage() { - return client.txt2img({ + state.setLoading(true); + + const output = await client.txt2img({ ...state.txt2img, model, platform, }); + + state.pushHistory(output); + state.setLoading(false); } const client = mustExist(useContext(ClientContext)); @@ -68,12 +70,6 @@ export function Txt2Img(props: Txt2ImgProps) { /> - ; } diff --git a/gui/src/main.tsx b/gui/src/main.tsx index 8c468d72..1f8828c7 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -7,7 +7,7 @@ import { QueryClient, QueryClientProvider } from 'react-query'; import { createStore, StoreApi } from 'zustand'; import { createJSONStorage, persist } from 'zustand/middleware'; -import { ApiClient, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js'; +import { ApiClient, ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js'; import { OnnxWeb } from './components/OnnxWeb.js'; import { ConfigState, loadConfig } from './config.js'; @@ -27,6 +27,17 @@ interface OnnxState { resetTxt2Img(): void; resetImg2Img(): void; resetInpaint(): void; + + history: { + images: Array; + limit: number; + loading: boolean; + }; + + setLimit(limit: number): void; + setLoading(loading: boolean): void; + setHistory(newHistory: Array): void; + pushHistory(newImage: ApiResponse): void; } export async function main() { @@ -38,6 +49,11 @@ export async function main() { const defaults = paramsFromConfig(params); const state = createStore(persist((set) => ({ defaults, + history: { + images: [], + limit: 4, + loading: false, + }, txt2img: { ...defaults, height: params.height.default, @@ -50,6 +66,45 @@ export async function main() { inpaint: { ...defaults, }, + setLimit(limit) { + set((oldState) => ({ + ...oldState, + history: { + ...oldState.history, + limit, + }, + })); + }, + setLoading(loading) { + set((oldState) => ({ + ...oldState, + history: { + ...oldState.history, + loading, + }, + })); + }, + pushHistory(newImage: ApiResponse) { + set((oldState) => ({ + ...oldState, + history: { + ...oldState.history, + images: [ + newImage, + ...oldState.history.images, + ].slice(0, oldState.history.limit), + }, + })); + }, + setHistory(newHistory: Array) { + set((oldState) => ({ + ...oldState, + history: { + ...oldState.history, + images: newHistory, + }, + })); + }, setDefaults(newParams) { set((oldState) => ({ ...oldState,