1
0
Fork 0

feat(gui): implement image polling on the client

This commit is contained in:
Sean Sube 2023-01-12 21:12:20 -06:00
parent 41935667c4
commit 1dbe275f41
10 changed files with 82 additions and 44 deletions

View File

@ -53,7 +53,10 @@ export interface OutpaintParams extends Img2ImgParams {
} }
export interface ApiResponse { export interface ApiResponse {
output: string; output: {
key: string;
url: string;
};
params: Txt2ImgResponse; params: Txt2ImgResponse;
} }
@ -68,6 +71,8 @@ export interface ApiClient {
inpaint(params: InpaintParams): Promise<ApiResponse>; inpaint(params: InpaintParams): Promise<ApiResponse>;
outpaint(params: OutpaintParams): Promise<ApiResponse>; outpaint(params: OutpaintParams): Promise<ApiResponse>;
ready(params: ApiResponse): Promise<{ready: boolean}>;
} }
export const STATUS_SUCCESS = 200; export const STATUS_SUCCESS = 200;
@ -94,11 +99,16 @@ export function joinPath(...parts: Array<string>): string {
} }
export async function imageFromResponse(root: string, res: Response): Promise<ApiResponse> { export async function imageFromResponse(root: string, res: Response): Promise<ApiResponse> {
type LimitedResponse = Omit<ApiResponse, 'output'> & {output: string};
if (res.status === STATUS_SUCCESS) { if (res.status === STATUS_SUCCESS) {
const data = await res.json() as ApiResponse; const data = await res.json() as LimitedResponse;
const output = new URL(joinPath('output', data.output), root).toString(); const url = new URL(joinPath('output', data.output), root).toString();
return { return {
output, output: {
key: data.output,
url,
},
params: data.params, params: data.params,
}; };
} else { } else {
@ -229,5 +239,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
async outpaint() { async outpaint() {
throw new NotImplementedError(); throw new NotImplementedError();
}, },
async ready(params: ApiResponse): Promise<{ready: boolean}> {
const path = new URL('ready', root);
path.searchParams.append('output', params.output.key);
const res = await f(path);
return await res.json() as {ready: boolean};
}
}; };
} }

View File

@ -28,13 +28,13 @@ export function ImageCard(props: ImageCardProps) {
} }
function downloadImage() { function downloadImage() {
window.open(output, '_blank'); window.open(output.url, '_blank');
} }
return <Card sx={{ maxWidth: params.width }} elevation={2}> return <Card sx={{ maxWidth: params.width }} elevation={2}>
<CardMedia sx={{ height: params.height }} <CardMedia sx={{ height: params.height }}
component='img' component='img'
image={output} image={output.url}
title={params.prompt} title={params.prompt}
/> />
<CardContent> <CardContent>

View File

@ -1,4 +1,4 @@
import { mustExist } from '@apextoaster/js-utils'; import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Grid } from '@mui/material'; import { Grid } from '@mui/material';
import { useContext } from 'react'; import { useContext } from 'react';
import * as React from 'react'; import * as React from 'react';
@ -17,14 +17,14 @@ export function ImageHistory() {
const children = []; const children = [];
if (loading) { if (doesExist(loading)) {
children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config children.push(<LoadingCard key='loading' loading={loading} />);
} }
if (history.length > 0) { if (history.length > 0) {
children.push(...history.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />)); children.push(...history.map((item) => <ImageCard key={item.output.key} value={item} onDelete={removeHistory} />));
} else { } else {
if (loading === false) { if (doesExist(loading) === false) {
children.push(<div>No results. Press Generate.</div>); children.push(<div>No results. Press Generate.</div>);
} }
} }

View File

@ -1,7 +1,7 @@
import { mustExist } from '@apextoaster/js-utils'; import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material'; import { Box, Button, Stack } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useMutation } from 'react-query'; import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigParams, IMAGE_FILTER } from '../config.js'; import { ConfigParams, IMAGE_FILTER } from '../config.js';
@ -23,8 +23,6 @@ export function Img2Img(props: Img2ImgProps) {
const { config, model, platform } = props; const { config, model, platform } = props;
async function uploadSource() { async function uploadSource() {
setLoading(true);
const output = await client.img2img({ const output = await client.img2img({
...params, ...params,
model, model,
@ -32,12 +30,14 @@ export function Img2Img(props: Img2ImgProps) {
source: mustExist(source), // TODO: show an error if this doesn't exist source: mustExist(source), // TODO: show an error if this doesn't exist
}); });
pushHistory(output); setLoading(output);
setLoading(false);
} }
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));
const upload = useMutation(uploadSource); const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
});
const state = mustExist(useContext(StateContext)); const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.img2img); const params = useStore(state, (s) => s.img2img);
@ -45,8 +45,6 @@ export function Img2Img(props: Img2ImgProps) {
const setImg2Img = useStore(state, (s) => s.setImg2Img); const setImg2Img = useStore(state, (s) => s.setImg2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method // eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading); const setLoading = useStore(state, (s) => s.setLoading);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const [source, setSource] = useState<File>(); const [source, setSource] = useState<File>();

View File

@ -2,7 +2,7 @@ import { doesExist, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient } from '@mui/icons-material'; import { FormatColorFill, Gradient } from '@mui/icons-material';
import { Box, Button, Stack } from '@mui/material'; import { Box, Button, Stack } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useMutation } from 'react-query'; import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js'; import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
@ -69,7 +69,6 @@ export function Inpaint(props: InpaintProps) {
async function uploadSource() { async function uploadSource() {
const canvas = mustExist(canvasRef.current); const canvas = mustExist(canvasRef.current);
setLoading(true);
return new Promise<void>((res, rej) => { return new Promise<void>((res, rej) => {
canvas.toBlob((blob) => { canvas.toBlob((blob) => {
client.inpaint({ client.inpaint({
@ -79,8 +78,7 @@ export function Inpaint(props: InpaintProps) {
mask: mustExist(blob), mask: mustExist(blob),
source: mustExist(source), source: mustExist(source),
}).then((output) => { }).then((output) => {
pushHistory(output); setLoading(output);
setLoading(false);
res(); res();
}).catch((err) => rej(err)); }).catch((err) => rej(err));
}); });
@ -146,7 +144,10 @@ export function Inpaint(props: InpaintProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method // eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory); const pushHistory = useStore(state, (s) => s.pushHistory);
const upload = useMutation(uploadSource); const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
});
// eslint-disable-next-line no-null/no-null // eslint-disable-next-line no-null/no-null
const canvasRef = useRef<HTMLCanvasElement>(null); const canvasRef = useRef<HTMLCanvasElement>(null);

View File

@ -1,15 +1,37 @@
import { mustExist } from '@apextoaster/js-utils';
import { Card, CardContent, CircularProgress } from '@mui/material'; import { Card, CardContent, CircularProgress } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useContext } from 'react';
import { useQuery } from 'react-query';
import { useStore } from 'zustand';
import { ApiResponse } from '../api/client.js';
import { POLL_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
export interface LoadingCardProps { export interface LoadingCardProps {
height: number; loading: ApiResponse;
width: number;
} }
export function LoadingCard(props: LoadingCardProps) { export function LoadingCard(props: LoadingCardProps) {
return <Card sx={{ maxWidth: props.width }}> const client = mustExist(React.useContext(ClientContext));
<CardContent sx={{ height: props.height }}>
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.height }}> // eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(mustExist(useContext(StateContext)), (state) => state.pushHistory);
const ready = useQuery('ready', () => client.ready(props.loading), {
refetchInterval: POLL_TIME,
});
React.useEffect(() => {
if (ready.status === 'success' && ready.data.ready) {
pushHistory(props.loading);
}
}, [ready.status, ready.data?.ready]);
return <Card sx={{ maxWidth: props.loading.params.width }}>
<CardContent sx={{ height: props.loading.params.height }}>
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.loading.params.height }}>
<CircularProgress /> <CircularProgress />
</div> </div>
</CardContent> </CardContent>

View File

@ -1,7 +1,7 @@
import { mustExist } from '@apextoaster/js-utils'; import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material'; import { Box, Button, Stack } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useMutation } from 'react-query'; import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigParams } from '../config.js'; import { ConfigParams } from '../config.js';
@ -22,20 +22,20 @@ export function Txt2Img(props: Txt2ImgProps) {
const { config, model, platform } = props; const { config, model, platform } = props;
async function generateImage() { async function generateImage() {
setLoading(true);
const output = await client.txt2img({ const output = await client.txt2img({
...params, ...params,
model, model,
platform, platform,
}); });
pushHistory(output); setLoading(output);
setLoading(false);
} }
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));
const generate = useMutation(generateImage); const query = useQueryClient();
const generate = useMutation(generateImage, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
});
const state = mustExist(useContext(StateContext)); const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.txt2img); const params = useStore(state, (s) => s.txt2img);

View File

@ -43,7 +43,8 @@ export const DEFAULT_BRUSH = {
size: 8, size: 8,
}; };
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png'; export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
export const STALE_TIME = 3_000; export const STALE_TIME = 300_000; // 5 minutes
export const POLL_TIME = 5_000; // 5 seconds
export async function loadConfig(): Promise<Config> { export async function loadConfig(): Promise<Config> {
const configPath = new URL('./config.json', window.origin); const configPath = new URL('./config.json', window.origin);

View File

@ -41,12 +41,8 @@ export async function main() {
...createDefaultSlice(...slice), ...createDefaultSlice(...slice),
}), { }), {
name: 'onnx-web', name: 'onnx-web',
partialize: (oldState) => ({
...oldState,
loading: false,
}),
storage: createJSONStorage(() => localStorage), storage: createJSONStorage(() => localStorage),
version: 2, version: 3,
})); }));
// prep react-query client // prep react-query client

View File

@ -39,12 +39,12 @@ interface InpaintSlice {
interface HistorySlice { interface HistorySlice {
history: Array<ApiResponse>; history: Array<ApiResponse>;
limit: number; limit: number;
loading: boolean; loading: Maybe<ApiResponse>;
pushHistory(image: ApiResponse): void; pushHistory(image: ApiResponse): void;
removeHistory(image: ApiResponse): void; removeHistory(image: ApiResponse): void;
setLimit(limit: number): void; setLimit(limit: number): void;
setLoading(loading: boolean): void; setLoading(image: Maybe<ApiResponse>): void;
} }
interface DefaultSlice { interface DefaultSlice {
@ -130,7 +130,8 @@ export function createStateSlices(base: ConfigParams) {
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({ const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
history: [], history: [],
limit: 4, limit: 4,
loading: false, // eslint-disable-next-line no-null/no-null
loading: null,
pushHistory(image) { pushHistory(image) {
set((prev) => ({ set((prev) => ({
...prev, ...prev,
@ -138,6 +139,8 @@ export function createStateSlices(base: ConfigParams) {
image, image,
...prev.history, ...prev.history,
], ],
// eslint-disable-next-line no-null/no-null
loading: null,
})); }));
}, },
removeHistory(image) { removeHistory(image) {