feat(gui): share image history between tabs, add setting to adjust length of history (fixes #22)
This commit is contained in:
parent
9bb01cc01d
commit
662bf42454
|
@ -0,0 +1,37 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Grid } from '@mui/material';
|
||||
import { useContext } from 'react';
|
||||
import * as React from 'react';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ApiResponse } from '../api/client.js';
|
||||
import { StateContext } from '../main.js';
|
||||
import { ImageCard } from './ImageCard.js';
|
||||
import { LoadingCard } from './LoadingCard.js';
|
||||
|
||||
export function ImageHistory() {
|
||||
const state = useStore(mustExist(useContext(StateContext)));
|
||||
const { images } = state.history;
|
||||
|
||||
const children = [];
|
||||
|
||||
if (state.history.loading) {
|
||||
children.push(<LoadingCard height={512} width={512} />); // TODO: get dimensions from config
|
||||
}
|
||||
|
||||
function removeHistory(image: ApiResponse) {
|
||||
state.setHistory(images.filter((item) => image.output !== item.output));
|
||||
}
|
||||
|
||||
if (images.length > 0) {
|
||||
children.push(...images.map((item) => <ImageCard value={item} onDelete={removeHistory} />));
|
||||
} else {
|
||||
if (state.history.loading === false) {
|
||||
children.push(<div>No results. Press Generate.</div>);
|
||||
}
|
||||
}
|
||||
|
||||
const limited = children.slice(0, state.history.limit);
|
||||
|
||||
return <Grid container spacing={2}>{limited.map((child, idx) => <Grid item key={idx} xs={6}>{child}</Grid>)}</Grid>;
|
||||
}
|
|
@ -4,13 +4,10 @@ import * as React from 'react';
|
|||
import { useMutation } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { equalResponse } from '../api/client.js';
|
||||
import { ConfigParams, IMAGE_FILTER } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../main.js';
|
||||
import { ImageCard } from './ImageCard.js';
|
||||
import { ImageControl } from './ImageControl.js';
|
||||
import { ImageInput } from './ImageInput.js';
|
||||
import { MutationHistory } from './MutationHistory.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
|
||||
const { useContext, useState } = React;
|
||||
|
@ -26,12 +23,17 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
const { config, model, platform } = props;
|
||||
|
||||
async function uploadSource() {
|
||||
return client.img2img({
|
||||
state.setLoading(true);
|
||||
|
||||
const output = await client.img2img({
|
||||
...state.img2img,
|
||||
model,
|
||||
platform,
|
||||
source: mustExist(source), // TODO: show an error if this doesn't exist
|
||||
});
|
||||
|
||||
state.pushHistory(output);
|
||||
state.setLoading(false);
|
||||
}
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
|
@ -60,9 +62,6 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
}}
|
||||
/>
|
||||
<Button onClick={() => upload.mutate()}>Generate</Button>
|
||||
<MutationHistory result={upload} limit={4} element={ImageCard}
|
||||
isEqual={equalResponse}
|
||||
/>
|
||||
</Stack>
|
||||
</Box>;
|
||||
}
|
||||
|
|
|
@ -5,13 +5,11 @@ import * as React from 'react';
|
|||
import { useMutation } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ApiResponse, equalResponse } from '../api/client.js';
|
||||
import { ApiResponse } from '../api/client.js';
|
||||
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../main.js';
|
||||
import { ImageCard } from './ImageCard.js';
|
||||
import { ImageControl } from './ImageControl.js';
|
||||
import { ImageInput } from './ImageInput.js';
|
||||
import { MutationHistory } from './MutationHistory.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
|
||||
const { useContext, useEffect, useRef, useState } = React;
|
||||
|
@ -72,15 +70,20 @@ export function Inpaint(props: InpaintProps) {
|
|||
|
||||
async function uploadSource() {
|
||||
const canvas = mustExist(canvasRef.current);
|
||||
return new Promise<ApiResponse>((res, _rej) => {
|
||||
state.setLoading(true);
|
||||
return new Promise<void>((res, rej) => {
|
||||
canvas.toBlob((blob) => {
|
||||
res(client.inpaint({
|
||||
client.inpaint({
|
||||
...state.inpaint,
|
||||
model,
|
||||
platform,
|
||||
mask: mustExist(blob),
|
||||
source: mustExist(source),
|
||||
}));
|
||||
}).then((output) => {
|
||||
state.pushHistory(output);
|
||||
state.setLoading(false);
|
||||
res();
|
||||
}).catch((err) => rej(err));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -262,9 +265,6 @@ export function Inpaint(props: InpaintProps) {
|
|||
}}
|
||||
/>
|
||||
<Button onClick={() => upload.mutate()}>Generate</Button>
|
||||
<MutationHistory result={upload} limit={4} element={ImageCard}
|
||||
isEqual={equalResponse}
|
||||
/>
|
||||
</Stack>
|
||||
</Box>;
|
||||
}
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
import { Grid } from '@mui/material';
|
||||
import { useState } from 'react';
|
||||
import * as React from 'react';
|
||||
import { UseMutationResult } from 'react-query';
|
||||
import { LoadingCard } from './LoadingCard.js';
|
||||
|
||||
export interface MutationHistoryChildProps<T> {
|
||||
value: T;
|
||||
|
||||
onDelete: (key: T) => void;
|
||||
}
|
||||
|
||||
export interface MutationHistoryProps<T> {
|
||||
element: React.ComponentType<MutationHistoryChildProps<T>>;
|
||||
limit: number;
|
||||
result: UseMutationResult<T, unknown, void>;
|
||||
|
||||
isEqual: (a: T, b: T) => boolean;
|
||||
}
|
||||
|
||||
export function MutationHistory<T>(props: MutationHistoryProps<T>) {
|
||||
const { limit, result } = props;
|
||||
const { status } = result;
|
||||
|
||||
const [history, setHistory] = useState<Array<T>>([]);
|
||||
const children = [];
|
||||
|
||||
if (status === 'loading') {
|
||||
children.push(<LoadingCard height={512} width={512} />); // TODO: get dimensions from parent
|
||||
}
|
||||
|
||||
if (status === 'success') {
|
||||
const { data } = result;
|
||||
if (history.some((other) => props.isEqual(data, other))) {
|
||||
// item already exists, skip it
|
||||
} else {
|
||||
setHistory([
|
||||
data,
|
||||
...history,
|
||||
].slice(0, limit));
|
||||
}
|
||||
}
|
||||
|
||||
function removeHistory(data: T) {
|
||||
setHistory(history.filter((item) => props.isEqual(item, data) === false));
|
||||
}
|
||||
|
||||
if (history.length > 0) {
|
||||
children.push(...history.map((item) => <props.element value={item} onDelete={removeHistory} />));
|
||||
} else {
|
||||
// only show the prompt when the button has not been pushed
|
||||
if (status !== 'loading') {
|
||||
children.push(<div>No results. Press Generate.</div>);
|
||||
}
|
||||
}
|
||||
|
||||
return <Grid container spacing={2}>{children.slice(0, limit).map((child, idx) => <Grid item key={idx} xs={6}>{child}</Grid>)}</Grid>;
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { TabContext, TabList, TabPanel } from '@mui/lab';
|
||||
import { Box, Container, Stack, Tab, Typography } from '@mui/material';
|
||||
import { Box, Container, Divider, Stack, Tab, Typography } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useQuery } from 'react-query';
|
||||
|
||||
|
@ -8,6 +8,7 @@ import { ApiClient } from '../api/client.js';
|
|||
import { ConfigParams, STALE_TIME } from '../config.js';
|
||||
import { ClientContext } from '../main.js';
|
||||
import { MODEL_LABELS, PLATFORM_LABELS } from '../strings.js';
|
||||
import { ImageHistory } from './ImageHistory.js';
|
||||
import { Img2Img } from './Img2Img.js';
|
||||
import { Inpaint } from './Inpaint.js';
|
||||
import { QueryList } from './QueryList.js';
|
||||
|
@ -44,7 +45,7 @@ export function OnnxWeb(props: OnnxWebProps) {
|
|||
ONNX Web
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box sx={{ my: 4 }}>
|
||||
<Box sx={{ mx: 4, my: 4 }}>
|
||||
<Stack direction='row' spacing={2}>
|
||||
<QueryList
|
||||
id='models'
|
||||
|
@ -92,6 +93,10 @@ export function OnnxWeb(props: OnnxWebProps) {
|
|||
<Settings config={config} />
|
||||
</TabPanel>
|
||||
</TabContext>
|
||||
<Divider variant='middle' />
|
||||
<Box sx={{ mx: 4, my: 4 }}>
|
||||
<ImageHistory />
|
||||
</Box>
|
||||
</Container>
|
||||
</div>
|
||||
);
|
||||
|
|
|
@ -5,6 +5,7 @@ import { useStore } from 'zustand';
|
|||
|
||||
import { ConfigParams } from '../config.js';
|
||||
import { StateContext } from '../main.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
|
||||
const { useContext } = React;
|
||||
|
||||
|
@ -16,12 +17,14 @@ export function Settings(_props: SettingsProps) {
|
|||
const state = useStore(mustExist(useContext(StateContext)));
|
||||
|
||||
return <Stack spacing={2}>
|
||||
<Stack direction='row' spacing={2}>
|
||||
<Button onClick={() => state.resetTxt2Img()}>Reset Txt2Img</Button>
|
||||
<Button onClick={() => state.resetImg2Img()}>Reset Img2Img</Button>
|
||||
<Button onClick={() => state.resetInpaint()}>Reset Inpaint</Button>
|
||||
<Button disabled>Reset All</Button>
|
||||
</Stack>
|
||||
<NumericField
|
||||
label='Image History'
|
||||
min={2}
|
||||
max={20}
|
||||
step={1}
|
||||
value={state.history.limit}
|
||||
onChange={(value) => state.setLimit(value)}
|
||||
/>
|
||||
<TextField variant='outlined' label='Default Model' value={state.defaults.model} onChange={(event) => {
|
||||
state.setDefaults({
|
||||
model: event.target.value,
|
||||
|
@ -42,5 +45,11 @@ export function Settings(_props: SettingsProps) {
|
|||
scheduler: event.target.value,
|
||||
});
|
||||
}} />
|
||||
<Stack direction='row' spacing={2}>
|
||||
<Button onClick={() => state.resetTxt2Img()}>Reset Txt2Img</Button>
|
||||
<Button onClick={() => state.resetImg2Img()}>Reset Img2Img</Button>
|
||||
<Button onClick={() => state.resetInpaint()}>Reset Inpaint</Button>
|
||||
<Button disabled>Reset All</Button>
|
||||
</Stack>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -4,15 +4,12 @@ import * as React from 'react';
|
|||
import { useMutation } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { BaseImgParams, equalResponse, paramsFromConfig } from '../api/client.js';
|
||||
import { ConfigParams } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../main.js';
|
||||
import { ImageCard } from './ImageCard.js';
|
||||
import { ImageControl } from './ImageControl.js';
|
||||
import { MutationHistory } from './MutationHistory.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
|
||||
const { useContext, useState } = React;
|
||||
const { useContext } = React;
|
||||
|
||||
export interface Txt2ImgProps {
|
||||
config: ConfigParams;
|
||||
|
@ -25,11 +22,16 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
const { config, model, platform } = props;
|
||||
|
||||
async function generateImage() {
|
||||
return client.txt2img({
|
||||
state.setLoading(true);
|
||||
|
||||
const output = await client.txt2img({
|
||||
...state.txt2img,
|
||||
model,
|
||||
platform,
|
||||
});
|
||||
|
||||
state.pushHistory(output);
|
||||
state.setLoading(false);
|
||||
}
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
|
@ -68,12 +70,6 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
/>
|
||||
</Stack>
|
||||
<Button onClick={() => generate.mutate()}>Generate</Button>
|
||||
<MutationHistory
|
||||
element={ImageCard}
|
||||
limit={4}
|
||||
isEqual={equalResponse}
|
||||
result={generate}
|
||||
/>
|
||||
</Stack>
|
||||
</Box>;
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import { QueryClient, QueryClientProvider } from 'react-query';
|
|||
import { createStore, StoreApi } from 'zustand';
|
||||
import { createJSONStorage, persist } from 'zustand/middleware';
|
||||
|
||||
import { ApiClient, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js';
|
||||
import { ApiClient, ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, makeClient, paramsFromConfig, Txt2ImgParams } from './api/client.js';
|
||||
import { OnnxWeb } from './components/OnnxWeb.js';
|
||||
import { ConfigState, loadConfig } from './config.js';
|
||||
|
||||
|
@ -27,6 +27,17 @@ interface OnnxState {
|
|||
resetTxt2Img(): void;
|
||||
resetImg2Img(): void;
|
||||
resetInpaint(): void;
|
||||
|
||||
history: {
|
||||
images: Array<ApiResponse>;
|
||||
limit: number;
|
||||
loading: boolean;
|
||||
};
|
||||
|
||||
setLimit(limit: number): void;
|
||||
setLoading(loading: boolean): void;
|
||||
setHistory(newHistory: Array<ApiResponse>): void;
|
||||
pushHistory(newImage: ApiResponse): void;
|
||||
}
|
||||
|
||||
export async function main() {
|
||||
|
@ -38,6 +49,11 @@ export async function main() {
|
|||
const defaults = paramsFromConfig(params);
|
||||
const state = createStore<OnnxState, [['zustand/persist', never]]>(persist((set) => ({
|
||||
defaults,
|
||||
history: {
|
||||
images: [],
|
||||
limit: 4,
|
||||
loading: false,
|
||||
},
|
||||
txt2img: {
|
||||
...defaults,
|
||||
height: params.height.default,
|
||||
|
@ -50,6 +66,45 @@ export async function main() {
|
|||
inpaint: {
|
||||
...defaults,
|
||||
},
|
||||
setLimit(limit) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
history: {
|
||||
...oldState.history,
|
||||
limit,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setLoading(loading) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
history: {
|
||||
...oldState.history,
|
||||
loading,
|
||||
},
|
||||
}));
|
||||
},
|
||||
pushHistory(newImage: ApiResponse) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
history: {
|
||||
...oldState.history,
|
||||
images: [
|
||||
newImage,
|
||||
...oldState.history.images,
|
||||
].slice(0, oldState.history.limit),
|
||||
},
|
||||
}));
|
||||
},
|
||||
setHistory(newHistory: Array<ApiResponse>) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
history: {
|
||||
...oldState.history,
|
||||
images: newHistory,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setDefaults(newParams) {
|
||||
set((oldState) => ({
|
||||
...oldState,
|
||||
|
|
Loading…
Reference in New Issue