feat(gui): implement image polling on the client
This commit is contained in:
parent
41935667c4
commit
1dbe275f41
|
@ -53,7 +53,10 @@ export interface OutpaintParams extends Img2ImgParams {
|
|||
}
|
||||
|
||||
export interface ApiResponse {
|
||||
output: string;
|
||||
output: {
|
||||
key: string;
|
||||
url: string;
|
||||
};
|
||||
params: Txt2ImgResponse;
|
||||
}
|
||||
|
||||
|
@ -68,6 +71,8 @@ export interface ApiClient {
|
|||
|
||||
inpaint(params: InpaintParams): Promise<ApiResponse>;
|
||||
outpaint(params: OutpaintParams): Promise<ApiResponse>;
|
||||
|
||||
ready(params: ApiResponse): Promise<{ready: boolean}>;
|
||||
}
|
||||
|
||||
export const STATUS_SUCCESS = 200;
|
||||
|
@ -94,11 +99,16 @@ export function joinPath(...parts: Array<string>): string {
|
|||
}
|
||||
|
||||
export async function imageFromResponse(root: string, res: Response): Promise<ApiResponse> {
|
||||
type LimitedResponse = Omit<ApiResponse, 'output'> & {output: string};
|
||||
|
||||
if (res.status === STATUS_SUCCESS) {
|
||||
const data = await res.json() as ApiResponse;
|
||||
const output = new URL(joinPath('output', data.output), root).toString();
|
||||
const data = await res.json() as LimitedResponse;
|
||||
const url = new URL(joinPath('output', data.output), root).toString();
|
||||
return {
|
||||
output,
|
||||
output: {
|
||||
key: data.output,
|
||||
url,
|
||||
},
|
||||
params: data.params,
|
||||
};
|
||||
} else {
|
||||
|
@ -229,5 +239,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
async outpaint() {
|
||||
throw new NotImplementedError();
|
||||
},
|
||||
async ready(params: ApiResponse): Promise<{ready: boolean}> {
|
||||
const path = new URL('ready', root);
|
||||
path.searchParams.append('output', params.output.key);
|
||||
|
||||
const res = await f(path);
|
||||
return await res.json() as {ready: boolean};
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -28,13 +28,13 @@ export function ImageCard(props: ImageCardProps) {
|
|||
}
|
||||
|
||||
function downloadImage() {
|
||||
window.open(output, '_blank');
|
||||
window.open(output.url, '_blank');
|
||||
}
|
||||
|
||||
return <Card sx={{ maxWidth: params.width }} elevation={2}>
|
||||
<CardMedia sx={{ height: params.height }}
|
||||
component='img'
|
||||
image={output}
|
||||
image={output.url}
|
||||
title={params.prompt}
|
||||
/>
|
||||
<CardContent>
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||
import { Grid } from '@mui/material';
|
||||
import { useContext } from 'react';
|
||||
import * as React from 'react';
|
||||
|
@ -17,14 +17,14 @@ export function ImageHistory() {
|
|||
|
||||
const children = [];
|
||||
|
||||
if (loading) {
|
||||
children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config
|
||||
if (doesExist(loading)) {
|
||||
children.push(<LoadingCard key='loading' loading={loading} />);
|
||||
}
|
||||
|
||||
if (history.length > 0) {
|
||||
children.push(...history.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
|
||||
children.push(...history.map((item) => <ImageCard key={item.output.key} value={item} onDelete={removeHistory} />));
|
||||
} else {
|
||||
if (loading === false) {
|
||||
if (doesExist(loading) === false) {
|
||||
children.push(<div>No results. Press Generate.</div>);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Box, Button, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useMutation } from 'react-query';
|
||||
import { useMutation, useQueryClient } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigParams, IMAGE_FILTER } from '../config.js';
|
||||
|
@ -23,8 +23,6 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
const { config, model, platform } = props;
|
||||
|
||||
async function uploadSource() {
|
||||
setLoading(true);
|
||||
|
||||
const output = await client.img2img({
|
||||
...params,
|
||||
model,
|
||||
|
@ -32,12 +30,14 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
source: mustExist(source), // TODO: show an error if this doesn't exist
|
||||
});
|
||||
|
||||
pushHistory(output);
|
||||
setLoading(false);
|
||||
setLoading(output);
|
||||
}
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
const upload = useMutation(uploadSource);
|
||||
const query = useQueryClient();
|
||||
const upload = useMutation(uploadSource, {
|
||||
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
|
||||
});
|
||||
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const params = useStore(state, (s) => s.img2img);
|
||||
|
@ -45,8 +45,6 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
const setImg2Img = useStore(state, (s) => s.setImg2Img);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setLoading = useStore(state, (s) => s.setLoading);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const pushHistory = useStore(state, (s) => s.pushHistory);
|
||||
|
||||
const [source, setSource] = useState<File>();
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import { doesExist, mustExist } from '@apextoaster/js-utils';
|
|||
import { FormatColorFill, Gradient } from '@mui/icons-material';
|
||||
import { Box, Button, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useMutation } from 'react-query';
|
||||
import { useMutation, useQueryClient } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
|
||||
|
@ -69,7 +69,6 @@ export function Inpaint(props: InpaintProps) {
|
|||
|
||||
async function uploadSource() {
|
||||
const canvas = mustExist(canvasRef.current);
|
||||
setLoading(true);
|
||||
return new Promise<void>((res, rej) => {
|
||||
canvas.toBlob((blob) => {
|
||||
client.inpaint({
|
||||
|
@ -79,8 +78,7 @@ export function Inpaint(props: InpaintProps) {
|
|||
mask: mustExist(blob),
|
||||
source: mustExist(source),
|
||||
}).then((output) => {
|
||||
pushHistory(output);
|
||||
setLoading(false);
|
||||
setLoading(output);
|
||||
res();
|
||||
}).catch((err) => rej(err));
|
||||
});
|
||||
|
@ -146,7 +144,10 @@ export function Inpaint(props: InpaintProps) {
|
|||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const pushHistory = useStore(state, (s) => s.pushHistory);
|
||||
|
||||
const upload = useMutation(uploadSource);
|
||||
const query = useQueryClient();
|
||||
const upload = useMutation(uploadSource, {
|
||||
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
|
||||
});
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
const canvasRef = useRef<HTMLCanvasElement>(null);
|
||||
|
||||
|
|
|
@ -1,15 +1,37 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Card, CardContent, CircularProgress } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
import { useQuery } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ApiResponse } from '../api/client.js';
|
||||
import { POLL_TIME } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../state.js';
|
||||
|
||||
export interface LoadingCardProps {
|
||||
height: number;
|
||||
width: number;
|
||||
loading: ApiResponse;
|
||||
}
|
||||
|
||||
export function LoadingCard(props: LoadingCardProps) {
|
||||
return <Card sx={{ maxWidth: props.width }}>
|
||||
<CardContent sx={{ height: props.height }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.height }}>
|
||||
const client = mustExist(React.useContext(ClientContext));
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const pushHistory = useStore(mustExist(useContext(StateContext)), (state) => state.pushHistory);
|
||||
|
||||
const ready = useQuery('ready', () => client.ready(props.loading), {
|
||||
refetchInterval: POLL_TIME,
|
||||
});
|
||||
|
||||
React.useEffect(() => {
|
||||
if (ready.status === 'success' && ready.data.ready) {
|
||||
pushHistory(props.loading);
|
||||
}
|
||||
}, [ready.status, ready.data?.ready]);
|
||||
|
||||
return <Card sx={{ maxWidth: props.loading.params.width }}>
|
||||
<CardContent sx={{ height: props.loading.params.height }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.loading.params.height }}>
|
||||
<CircularProgress />
|
||||
</div>
|
||||
</CardContent>
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Box, Button, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useMutation } from 'react-query';
|
||||
import { useMutation, useQueryClient } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigParams } from '../config.js';
|
||||
|
@ -22,20 +22,20 @@ export function Txt2Img(props: Txt2ImgProps) {
|
|||
const { config, model, platform } = props;
|
||||
|
||||
async function generateImage() {
|
||||
setLoading(true);
|
||||
|
||||
const output = await client.txt2img({
|
||||
...params,
|
||||
model,
|
||||
platform,
|
||||
});
|
||||
|
||||
pushHistory(output);
|
||||
setLoading(false);
|
||||
setLoading(output);
|
||||
}
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
const generate = useMutation(generateImage);
|
||||
const query = useQueryClient();
|
||||
const generate = useMutation(generateImage, {
|
||||
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
|
||||
});
|
||||
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const params = useStore(state, (s) => s.txt2img);
|
||||
|
|
|
@ -43,7 +43,8 @@ export const DEFAULT_BRUSH = {
|
|||
size: 8,
|
||||
};
|
||||
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
|
||||
export const STALE_TIME = 3_000;
|
||||
export const STALE_TIME = 300_000; // 5 minutes
|
||||
export const POLL_TIME = 5_000; // 5 seconds
|
||||
|
||||
export async function loadConfig(): Promise<Config> {
|
||||
const configPath = new URL('./config.json', window.origin);
|
||||
|
|
|
@ -41,12 +41,8 @@ export async function main() {
|
|||
...createDefaultSlice(...slice),
|
||||
}), {
|
||||
name: 'onnx-web',
|
||||
partialize: (oldState) => ({
|
||||
...oldState,
|
||||
loading: false,
|
||||
}),
|
||||
storage: createJSONStorage(() => localStorage),
|
||||
version: 2,
|
||||
version: 3,
|
||||
}));
|
||||
|
||||
// prep react-query client
|
||||
|
|
|
@ -39,12 +39,12 @@ interface InpaintSlice {
|
|||
interface HistorySlice {
|
||||
history: Array<ApiResponse>;
|
||||
limit: number;
|
||||
loading: boolean;
|
||||
loading: Maybe<ApiResponse>;
|
||||
|
||||
pushHistory(image: ApiResponse): void;
|
||||
removeHistory(image: ApiResponse): void;
|
||||
setLimit(limit: number): void;
|
||||
setLoading(loading: boolean): void;
|
||||
setLoading(image: Maybe<ApiResponse>): void;
|
||||
}
|
||||
|
||||
interface DefaultSlice {
|
||||
|
@ -130,7 +130,8 @@ export function createStateSlices(base: ConfigParams) {
|
|||
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
|
||||
history: [],
|
||||
limit: 4,
|
||||
loading: false,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
loading: null,
|
||||
pushHistory(image) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
|
@ -138,6 +139,8 @@ export function createStateSlices(base: ConfigParams) {
|
|||
image,
|
||||
...prev.history,
|
||||
],
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
loading: null,
|
||||
}));
|
||||
},
|
||||
removeHistory(image) {
|
||||
|
|
Loading…
Reference in New Issue