feat(gui): implement image polling on the client
This commit is contained in:
parent
55e8b800d2
commit
c36daddf66
|
@ -53,7 +53,10 @@ export interface OutpaintParams extends Img2ImgParams {
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ApiResponse {
|
export interface ApiResponse {
|
||||||
output: string;
|
output: {
|
||||||
|
key: string;
|
||||||
|
url: string;
|
||||||
|
};
|
||||||
params: Txt2ImgResponse;
|
params: Txt2ImgResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,6 +71,8 @@ export interface ApiClient {
|
||||||
|
|
||||||
inpaint(params: InpaintParams): Promise<ApiResponse>;
|
inpaint(params: InpaintParams): Promise<ApiResponse>;
|
||||||
outpaint(params: OutpaintParams): Promise<ApiResponse>;
|
outpaint(params: OutpaintParams): Promise<ApiResponse>;
|
||||||
|
|
||||||
|
ready(params: ApiResponse): Promise<{ready: boolean}>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const STATUS_SUCCESS = 200;
|
export const STATUS_SUCCESS = 200;
|
||||||
|
@ -94,11 +99,16 @@ export function joinPath(...parts: Array<string>): string {
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function imageFromResponse(root: string, res: Response): Promise<ApiResponse> {
|
export async function imageFromResponse(root: string, res: Response): Promise<ApiResponse> {
|
||||||
|
type LimitedResponse = Omit<ApiResponse, 'output'> & {output: string};
|
||||||
|
|
||||||
if (res.status === STATUS_SUCCESS) {
|
if (res.status === STATUS_SUCCESS) {
|
||||||
const data = await res.json() as ApiResponse;
|
const data = await res.json() as LimitedResponse;
|
||||||
const output = new URL(joinPath('output', data.output), root).toString();
|
const url = new URL(joinPath('output', data.output), root).toString();
|
||||||
return {
|
return {
|
||||||
output,
|
output: {
|
||||||
|
key: data.output,
|
||||||
|
url,
|
||||||
|
},
|
||||||
params: data.params,
|
params: data.params,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
|
@ -229,5 +239,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
async outpaint() {
|
async outpaint() {
|
||||||
throw new NotImplementedError();
|
throw new NotImplementedError();
|
||||||
},
|
},
|
||||||
|
async ready(params: ApiResponse): Promise<{ready: boolean}> {
|
||||||
|
const path = new URL('ready', root);
|
||||||
|
path.searchParams.append('output', params.output.key);
|
||||||
|
|
||||||
|
const res = await f(path);
|
||||||
|
return await res.json() as {ready: boolean};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,13 +28,13 @@ export function ImageCard(props: ImageCardProps) {
|
||||||
}
|
}
|
||||||
|
|
||||||
function downloadImage() {
|
function downloadImage() {
|
||||||
window.open(output, '_blank');
|
window.open(output.url, '_blank');
|
||||||
}
|
}
|
||||||
|
|
||||||
return <Card sx={{ maxWidth: params.width }} elevation={2}>
|
return <Card sx={{ maxWidth: params.width }} elevation={2}>
|
||||||
<CardMedia sx={{ height: params.height }}
|
<CardMedia sx={{ height: params.height }}
|
||||||
component='img'
|
component='img'
|
||||||
image={output}
|
image={output.url}
|
||||||
title={params.prompt}
|
title={params.prompt}
|
||||||
/>
|
/>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { mustExist } from '@apextoaster/js-utils';
|
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||||
import { Grid } from '@mui/material';
|
import { Grid } from '@mui/material';
|
||||||
import { useContext } from 'react';
|
import { useContext } from 'react';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
|
@ -17,14 +17,14 @@ export function ImageHistory() {
|
||||||
|
|
||||||
const children = [];
|
const children = [];
|
||||||
|
|
||||||
if (loading) {
|
if (doesExist(loading)) {
|
||||||
children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config
|
children.push(<LoadingCard key='loading' loading={loading} />);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (history.length > 0) {
|
if (history.length > 0) {
|
||||||
children.push(...history.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
|
children.push(...history.map((item) => <ImageCard key={item.output.key} value={item} onDelete={removeHistory} />));
|
||||||
} else {
|
} else {
|
||||||
if (loading === false) {
|
if (doesExist(loading) === false) {
|
||||||
children.push(<div>No results. Press Generate.</div>);
|
children.push(<div>No results. Press Generate.</div>);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import { mustExist } from '@apextoaster/js-utils';
|
import { mustExist } from '@apextoaster/js-utils';
|
||||||
import { Box, Button, Stack } from '@mui/material';
|
import { Box, Button, Stack } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { useMutation } from 'react-query';
|
import { useMutation, useQueryClient } from 'react-query';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ConfigParams, IMAGE_FILTER } from '../config.js';
|
import { ConfigParams, IMAGE_FILTER } from '../config.js';
|
||||||
|
@ -23,8 +23,6 @@ export function Img2Img(props: Img2ImgProps) {
|
||||||
const { config, model, platform } = props;
|
const { config, model, platform } = props;
|
||||||
|
|
||||||
async function uploadSource() {
|
async function uploadSource() {
|
||||||
setLoading(true);
|
|
||||||
|
|
||||||
const output = await client.img2img({
|
const output = await client.img2img({
|
||||||
...params,
|
...params,
|
||||||
model,
|
model,
|
||||||
|
@ -32,12 +30,14 @@ export function Img2Img(props: Img2ImgProps) {
|
||||||
source: mustExist(source), // TODO: show an error if this doesn't exist
|
source: mustExist(source), // TODO: show an error if this doesn't exist
|
||||||
});
|
});
|
||||||
|
|
||||||
pushHistory(output);
|
setLoading(output);
|
||||||
setLoading(false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const client = mustExist(useContext(ClientContext));
|
const client = mustExist(useContext(ClientContext));
|
||||||
const upload = useMutation(uploadSource);
|
const query = useQueryClient();
|
||||||
|
const upload = useMutation(uploadSource, {
|
||||||
|
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
|
||||||
|
});
|
||||||
|
|
||||||
const state = mustExist(useContext(StateContext));
|
const state = mustExist(useContext(StateContext));
|
||||||
const params = useStore(state, (s) => s.img2img);
|
const params = useStore(state, (s) => s.img2img);
|
||||||
|
@ -45,8 +45,6 @@ export function Img2Img(props: Img2ImgProps) {
|
||||||
const setImg2Img = useStore(state, (s) => s.setImg2Img);
|
const setImg2Img = useStore(state, (s) => s.setImg2Img);
|
||||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||||
const setLoading = useStore(state, (s) => s.setLoading);
|
const setLoading = useStore(state, (s) => s.setLoading);
|
||||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
|
||||||
const pushHistory = useStore(state, (s) => s.pushHistory);
|
|
||||||
|
|
||||||
const [source, setSource] = useState<File>();
|
const [source, setSource] = useState<File>();
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||||
import { FormatColorFill, Gradient } from '@mui/icons-material';
|
import { FormatColorFill, Gradient } from '@mui/icons-material';
|
||||||
import { Box, Button, Stack } from '@mui/material';
|
import { Box, Button, Stack } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { useMutation } from 'react-query';
|
import { useMutation, useQueryClient } from 'react-query';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
|
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
|
||||||
|
@ -69,7 +69,6 @@ export function Inpaint(props: InpaintProps) {
|
||||||
|
|
||||||
async function uploadSource() {
|
async function uploadSource() {
|
||||||
const canvas = mustExist(canvasRef.current);
|
const canvas = mustExist(canvasRef.current);
|
||||||
setLoading(true);
|
|
||||||
return new Promise<void>((res, rej) => {
|
return new Promise<void>((res, rej) => {
|
||||||
canvas.toBlob((blob) => {
|
canvas.toBlob((blob) => {
|
||||||
client.inpaint({
|
client.inpaint({
|
||||||
|
@ -79,8 +78,7 @@ export function Inpaint(props: InpaintProps) {
|
||||||
mask: mustExist(blob),
|
mask: mustExist(blob),
|
||||||
source: mustExist(source),
|
source: mustExist(source),
|
||||||
}).then((output) => {
|
}).then((output) => {
|
||||||
pushHistory(output);
|
setLoading(output);
|
||||||
setLoading(false);
|
|
||||||
res();
|
res();
|
||||||
}).catch((err) => rej(err));
|
}).catch((err) => rej(err));
|
||||||
});
|
});
|
||||||
|
@ -146,7 +144,10 @@ export function Inpaint(props: InpaintProps) {
|
||||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||||
const pushHistory = useStore(state, (s) => s.pushHistory);
|
const pushHistory = useStore(state, (s) => s.pushHistory);
|
||||||
|
|
||||||
const upload = useMutation(uploadSource);
|
const query = useQueryClient();
|
||||||
|
const upload = useMutation(uploadSource, {
|
||||||
|
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
|
||||||
|
});
|
||||||
// eslint-disable-next-line no-null/no-null
|
// eslint-disable-next-line no-null/no-null
|
||||||
const canvasRef = useRef<HTMLCanvasElement>(null);
|
const canvasRef = useRef<HTMLCanvasElement>(null);
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,37 @@
|
||||||
|
import { mustExist } from '@apextoaster/js-utils';
|
||||||
import { Card, CardContent, CircularProgress } from '@mui/material';
|
import { Card, CardContent, CircularProgress } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
|
import { useContext } from 'react';
|
||||||
|
import { useQuery } from 'react-query';
|
||||||
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
|
import { ApiResponse } from '../api/client.js';
|
||||||
|
import { POLL_TIME } from '../config.js';
|
||||||
|
import { ClientContext, StateContext } from '../state.js';
|
||||||
|
|
||||||
export interface LoadingCardProps {
|
export interface LoadingCardProps {
|
||||||
height: number;
|
loading: ApiResponse;
|
||||||
width: number;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function LoadingCard(props: LoadingCardProps) {
|
export function LoadingCard(props: LoadingCardProps) {
|
||||||
return <Card sx={{ maxWidth: props.width }}>
|
const client = mustExist(React.useContext(ClientContext));
|
||||||
<CardContent sx={{ height: props.height }}>
|
|
||||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.height }}>
|
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||||
|
const pushHistory = useStore(mustExist(useContext(StateContext)), (state) => state.pushHistory);
|
||||||
|
|
||||||
|
const ready = useQuery('ready', () => client.ready(props.loading), {
|
||||||
|
refetchInterval: POLL_TIME,
|
||||||
|
});
|
||||||
|
|
||||||
|
React.useEffect(() => {
|
||||||
|
if (ready.status === 'success' && ready.data.ready) {
|
||||||
|
pushHistory(props.loading);
|
||||||
|
}
|
||||||
|
}, [ready.status, ready.data?.ready]);
|
||||||
|
|
||||||
|
return <Card sx={{ maxWidth: props.loading.params.width }}>
|
||||||
|
<CardContent sx={{ height: props.loading.params.height }}>
|
||||||
|
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.loading.params.height }}>
|
||||||
<CircularProgress />
|
<CircularProgress />
|
||||||
</div>
|
</div>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import { mustExist } from '@apextoaster/js-utils';
|
import { mustExist } from '@apextoaster/js-utils';
|
||||||
import { Box, Button, Stack } from '@mui/material';
|
import { Box, Button, Stack } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { useMutation } from 'react-query';
|
import { useMutation, useQueryClient } from 'react-query';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ConfigParams } from '../config.js';
|
import { ConfigParams } from '../config.js';
|
||||||
|
@ -22,20 +22,20 @@ export function Txt2Img(props: Txt2ImgProps) {
|
||||||
const { config, model, platform } = props;
|
const { config, model, platform } = props;
|
||||||
|
|
||||||
async function generateImage() {
|
async function generateImage() {
|
||||||
setLoading(true);
|
|
||||||
|
|
||||||
const output = await client.txt2img({
|
const output = await client.txt2img({
|
||||||
...params,
|
...params,
|
||||||
model,
|
model,
|
||||||
platform,
|
platform,
|
||||||
});
|
});
|
||||||
|
|
||||||
pushHistory(output);
|
setLoading(output);
|
||||||
setLoading(false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const client = mustExist(useContext(ClientContext));
|
const client = mustExist(useContext(ClientContext));
|
||||||
const generate = useMutation(generateImage);
|
const query = useQueryClient();
|
||||||
|
const generate = useMutation(generateImage, {
|
||||||
|
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
|
||||||
|
});
|
||||||
|
|
||||||
const state = mustExist(useContext(StateContext));
|
const state = mustExist(useContext(StateContext));
|
||||||
const params = useStore(state, (s) => s.txt2img);
|
const params = useStore(state, (s) => s.txt2img);
|
||||||
|
|
|
@ -43,7 +43,8 @@ export const DEFAULT_BRUSH = {
|
||||||
size: 8,
|
size: 8,
|
||||||
};
|
};
|
||||||
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
|
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
|
||||||
export const STALE_TIME = 3_000;
|
export const STALE_TIME = 300_000; // 5 minutes
|
||||||
|
export const POLL_TIME = 5_000; // 5 seconds
|
||||||
|
|
||||||
export async function loadConfig(): Promise<Config> {
|
export async function loadConfig(): Promise<Config> {
|
||||||
const configPath = new URL('./config.json', window.origin);
|
const configPath = new URL('./config.json', window.origin);
|
||||||
|
|
|
@ -41,12 +41,8 @@ export async function main() {
|
||||||
...createDefaultSlice(...slice),
|
...createDefaultSlice(...slice),
|
||||||
}), {
|
}), {
|
||||||
name: 'onnx-web',
|
name: 'onnx-web',
|
||||||
partialize: (oldState) => ({
|
|
||||||
...oldState,
|
|
||||||
loading: false,
|
|
||||||
}),
|
|
||||||
storage: createJSONStorage(() => localStorage),
|
storage: createJSONStorage(() => localStorage),
|
||||||
version: 2,
|
version: 3,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// prep react-query client
|
// prep react-query client
|
||||||
|
|
|
@ -39,12 +39,12 @@ interface InpaintSlice {
|
||||||
interface HistorySlice {
|
interface HistorySlice {
|
||||||
history: Array<ApiResponse>;
|
history: Array<ApiResponse>;
|
||||||
limit: number;
|
limit: number;
|
||||||
loading: boolean;
|
loading: Maybe<ApiResponse>;
|
||||||
|
|
||||||
pushHistory(image: ApiResponse): void;
|
pushHistory(image: ApiResponse): void;
|
||||||
removeHistory(image: ApiResponse): void;
|
removeHistory(image: ApiResponse): void;
|
||||||
setLimit(limit: number): void;
|
setLimit(limit: number): void;
|
||||||
setLoading(loading: boolean): void;
|
setLoading(image: Maybe<ApiResponse>): void;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface DefaultSlice {
|
interface DefaultSlice {
|
||||||
|
@ -130,7 +130,8 @@ export function createStateSlices(base: ConfigParams) {
|
||||||
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
|
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
|
||||||
history: [],
|
history: [],
|
||||||
limit: 4,
|
limit: 4,
|
||||||
loading: false,
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
loading: null,
|
||||||
pushHistory(image) {
|
pushHistory(image) {
|
||||||
set((prev) => ({
|
set((prev) => ({
|
||||||
...prev,
|
...prev,
|
||||||
|
@ -138,6 +139,8 @@ export function createStateSlices(base: ConfigParams) {
|
||||||
image,
|
image,
|
||||||
...prev.history,
|
...prev.history,
|
||||||
],
|
],
|
||||||
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
loading: null,
|
||||||
}));
|
}));
|
||||||
},
|
},
|
||||||
removeHistory(image) {
|
removeHistory(image) {
|
||||||
|
|
Loading…
Reference in New Issue