1
0
Fork 0

feat(gui): implement image polling on the client

This commit is contained in:
Sean Sube 2023-01-12 21:12:20 -06:00
parent 55e8b800d2
commit c36daddf66
10 changed files with 82 additions and 44 deletions

View File

@ -53,7 +53,10 @@ export interface OutpaintParams extends Img2ImgParams {
}
export interface ApiResponse {
output: string;
output: {
key: string;
url: string;
};
params: Txt2ImgResponse;
}
@ -68,6 +71,8 @@ export interface ApiClient {
inpaint(params: InpaintParams): Promise<ApiResponse>;
outpaint(params: OutpaintParams): Promise<ApiResponse>;
ready(params: ApiResponse): Promise<{ready: boolean}>;
}
export const STATUS_SUCCESS = 200;
@ -94,11 +99,16 @@ export function joinPath(...parts: Array<string>): string {
}
export async function imageFromResponse(root: string, res: Response): Promise<ApiResponse> {
type LimitedResponse = Omit<ApiResponse, 'output'> & {output: string};
if (res.status === STATUS_SUCCESS) {
const data = await res.json() as ApiResponse;
const output = new URL(joinPath('output', data.output), root).toString();
const data = await res.json() as LimitedResponse;
const url = new URL(joinPath('output', data.output), root).toString();
return {
output,
output: {
key: data.output,
url,
},
params: data.params,
};
} else {
@ -229,5 +239,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
async outpaint() {
throw new NotImplementedError();
},
async ready(params: ApiResponse): Promise<{ready: boolean}> {
const path = new URL('ready', root);
path.searchParams.append('output', params.output.key);
const res = await f(path);
return await res.json() as {ready: boolean};
}
};
}

View File

@ -28,13 +28,13 @@ export function ImageCard(props: ImageCardProps) {
}
function downloadImage() {
window.open(output, '_blank');
window.open(output.url, '_blank');
}
return <Card sx={{ maxWidth: params.width }} elevation={2}>
<CardMedia sx={{ height: params.height }}
component='img'
image={output}
image={output.url}
title={params.prompt}
/>
<CardContent>

View File

@ -1,4 +1,4 @@
import { mustExist } from '@apextoaster/js-utils';
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Grid } from '@mui/material';
import { useContext } from 'react';
import * as React from 'react';
@ -17,14 +17,14 @@ export function ImageHistory() {
const children = [];
if (loading) {
children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config
if (doesExist(loading)) {
children.push(<LoadingCard key='loading' loading={loading} />);
}
if (history.length > 0) {
children.push(...history.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
children.push(...history.map((item) => <ImageCard key={item.output.key} value={item} onDelete={removeHistory} />));
} else {
if (loading === false) {
if (doesExist(loading) === false) {
children.push(<div>No results. Press Generate.</div>);
}
}

View File

@ -1,7 +1,7 @@
import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation } from 'react-query';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';
import { ConfigParams, IMAGE_FILTER } from '../config.js';
@ -23,8 +23,6 @@ export function Img2Img(props: Img2ImgProps) {
const { config, model, platform } = props;
async function uploadSource() {
setLoading(true);
const output = await client.img2img({
...params,
model,
@ -32,12 +30,14 @@ export function Img2Img(props: Img2ImgProps) {
source: mustExist(source), // TODO: show an error if this doesn't exist
});
pushHistory(output);
setLoading(false);
setLoading(output);
}
const client = mustExist(useContext(ClientContext));
const upload = useMutation(uploadSource);
const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
});
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.img2img);
@ -45,8 +45,6 @@ export function Img2Img(props: Img2ImgProps) {
const setImg2Img = useStore(state, (s) => s.setImg2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const [source, setSource] = useState<File>();

View File

@ -2,7 +2,7 @@ import { doesExist, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient } from '@mui/icons-material';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation } from 'react-query';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
@ -69,7 +69,6 @@ export function Inpaint(props: InpaintProps) {
async function uploadSource() {
const canvas = mustExist(canvasRef.current);
setLoading(true);
return new Promise<void>((res, rej) => {
canvas.toBlob((blob) => {
client.inpaint({
@ -79,8 +78,7 @@ export function Inpaint(props: InpaintProps) {
mask: mustExist(blob),
source: mustExist(source),
}).then((output) => {
pushHistory(output);
setLoading(false);
setLoading(output);
res();
}).catch((err) => rej(err));
});
@ -146,7 +144,10 @@ export function Inpaint(props: InpaintProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const upload = useMutation(uploadSource);
const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
});
// eslint-disable-next-line no-null/no-null
const canvasRef = useRef<HTMLCanvasElement>(null);

View File

@ -1,15 +1,37 @@
import { mustExist } from '@apextoaster/js-utils';
import { Card, CardContent, CircularProgress } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useQuery } from 'react-query';
import { useStore } from 'zustand';
import { ApiResponse } from '../api/client.js';
import { POLL_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
export interface LoadingCardProps {
height: number;
width: number;
loading: ApiResponse;
}
export function LoadingCard(props: LoadingCardProps) {
return <Card sx={{ maxWidth: props.width }}>
<CardContent sx={{ height: props.height }}>
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.height }}>
const client = mustExist(React.useContext(ClientContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(mustExist(useContext(StateContext)), (state) => state.pushHistory);
const ready = useQuery('ready', () => client.ready(props.loading), {
refetchInterval: POLL_TIME,
});
React.useEffect(() => {
if (ready.status === 'success' && ready.data.ready) {
pushHistory(props.loading);
}
}, [ready.status, ready.data?.ready]);
return <Card sx={{ maxWidth: props.loading.params.width }}>
<CardContent sx={{ height: props.loading.params.height }}>
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.loading.params.height }}>
<CircularProgress />
</div>
</CardContent>

View File

@ -1,7 +1,7 @@
import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation } from 'react-query';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';
import { ConfigParams } from '../config.js';
@ -22,20 +22,20 @@ export function Txt2Img(props: Txt2ImgProps) {
const { config, model, platform } = props;
async function generateImage() {
setLoading(true);
const output = await client.txt2img({
...params,
model,
platform,
});
pushHistory(output);
setLoading(false);
setLoading(output);
}
const client = mustExist(useContext(ClientContext));
const generate = useMutation(generateImage);
const query = useQueryClient();
const generate = useMutation(generateImage, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
});
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.txt2img);

View File

@ -43,7 +43,8 @@ export const DEFAULT_BRUSH = {
size: 8,
};
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
export const STALE_TIME = 3_000;
export const STALE_TIME = 300_000; // 5 minutes
export const POLL_TIME = 5_000; // 5 seconds
export async function loadConfig(): Promise<Config> {
const configPath = new URL('./config.json', window.origin);

View File

@ -41,12 +41,8 @@ export async function main() {
...createDefaultSlice(...slice),
}), {
name: 'onnx-web',
partialize: (oldState) => ({
...oldState,
loading: false,
}),
storage: createJSONStorage(() => localStorage),
version: 2,
version: 3,
}));
// prep react-query client

View File

@ -39,12 +39,12 @@ interface InpaintSlice {
interface HistorySlice {
history: Array<ApiResponse>;
limit: number;
loading: boolean;
loading: Maybe<ApiResponse>;
pushHistory(image: ApiResponse): void;
removeHistory(image: ApiResponse): void;
setLimit(limit: number): void;
setLoading(loading: boolean): void;
setLoading(image: Maybe<ApiResponse>): void;
}
interface DefaultSlice {
@ -130,7 +130,8 @@ export function createStateSlices(base: ConfigParams) {
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
history: [],
limit: 4,
loading: false,
// eslint-disable-next-line no-null/no-null
loading: null,
pushHistory(image) {
set((prev) => ({
...prev,
@ -138,6 +139,8 @@ export function createStateSlices(base: ConfigParams) {
image,
...prev.history,
],
// eslint-disable-next-line no-null/no-null
loading: null,
}));
},
removeHistory(image) {