From c36daddf66f19dbeab58034b8b7c35b85eca5706 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 12 Jan 2023 21:12:20 -0600 Subject: [PATCH] feat(gui): implement image polling on the client --- gui/src/api/client.ts | 25 ++++++++++++++++++---- gui/src/components/ImageCard.tsx | 4 ++-- gui/src/components/ImageHistory.tsx | 10 ++++----- gui/src/components/Img2Img.tsx | 14 ++++++------- gui/src/components/Inpaint.tsx | 11 +++++----- gui/src/components/LoadingCard.tsx | 32 ++++++++++++++++++++++++----- gui/src/components/Txt2Img.tsx | 12 +++++------ gui/src/config.ts | 3 ++- gui/src/main.tsx | 6 +----- gui/src/state.ts | 9 +++++--- 10 files changed, 82 insertions(+), 44 deletions(-) diff --git a/gui/src/api/client.ts b/gui/src/api/client.ts index 9eb01d35..6b8b108d 100644 --- a/gui/src/api/client.ts +++ b/gui/src/api/client.ts @@ -53,7 +53,10 @@ export interface OutpaintParams extends Img2ImgParams { } export interface ApiResponse { - output: string; + output: { + key: string; + url: string; + }; params: Txt2ImgResponse; } @@ -68,6 +71,8 @@ export interface ApiClient { inpaint(params: InpaintParams): Promise; outpaint(params: OutpaintParams): Promise; + + ready(params: ApiResponse): Promise<{ready: boolean}>; } export const STATUS_SUCCESS = 200; @@ -94,11 +99,16 @@ export function joinPath(...parts: Array): string { } export async function imageFromResponse(root: string, res: Response): Promise { + type LimitedResponse = Omit & {output: string}; + if (res.status === STATUS_SUCCESS) { - const data = await res.json() as ApiResponse; - const output = new URL(joinPath('output', data.output), root).toString(); + const data = await res.json() as LimitedResponse; + const url = new URL(joinPath('output', data.output), root).toString(); return { - output, + output: { + key: data.output, + url, + }, params: data.params, }; } else { @@ -229,5 +239,12 @@ export function makeClient(root: string, f = fetch): ApiClient { async outpaint() { throw new NotImplementedError(); }, + async ready(params: ApiResponse): Promise<{ready: boolean}> { + const path = new URL('ready', root); + path.searchParams.append('output', params.output.key); + + const res = await f(path); + return await res.json() as {ready: boolean}; + } }; } diff --git a/gui/src/components/ImageCard.tsx b/gui/src/components/ImageCard.tsx index abfeea73..8a7cd665 100644 --- a/gui/src/components/ImageCard.tsx +++ b/gui/src/components/ImageCard.tsx @@ -28,13 +28,13 @@ export function ImageCard(props: ImageCardProps) { } function downloadImage() { - window.open(output, '_blank'); + window.open(output.url, '_blank'); } return diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx index ba006127..3c295e4b 100644 --- a/gui/src/components/ImageHistory.tsx +++ b/gui/src/components/ImageHistory.tsx @@ -1,4 +1,4 @@ -import { mustExist } from '@apextoaster/js-utils'; +import { doesExist, mustExist } from '@apextoaster/js-utils'; import { Grid } from '@mui/material'; import { useContext } from 'react'; import * as React from 'react'; @@ -17,14 +17,14 @@ export function ImageHistory() { const children = []; - if (loading) { - children.push(); // TODO: get dimensions from config + if (doesExist(loading)) { + children.push(); } if (history.length > 0) { - children.push(...history.map((item) => )); + children.push(...history.map((item) => )); } else { - if (loading === false) { + if (doesExist(loading) === false) { children.push(
No results. Press Generate.
); } } diff --git a/gui/src/components/Img2Img.tsx b/gui/src/components/Img2Img.tsx index 8e5c9018..123d6401 100644 --- a/gui/src/components/Img2Img.tsx +++ b/gui/src/components/Img2Img.tsx @@ -1,7 +1,7 @@ import { mustExist } from '@apextoaster/js-utils'; import { Box, Button, Stack } from '@mui/material'; import * as React from 'react'; -import { useMutation } from 'react-query'; +import { useMutation, useQueryClient } from 'react-query'; import { useStore } from 'zustand'; import { ConfigParams, IMAGE_FILTER } from '../config.js'; @@ -23,8 +23,6 @@ export function Img2Img(props: Img2ImgProps) { const { config, model, platform } = props; async function uploadSource() { - setLoading(true); - const output = await client.img2img({ ...params, model, @@ -32,12 +30,14 @@ export function Img2Img(props: Img2ImgProps) { source: mustExist(source), // TODO: show an error if this doesn't exist }); - pushHistory(output); - setLoading(false); + setLoading(output); } const client = mustExist(useContext(ClientContext)); - const upload = useMutation(uploadSource); + const query = useQueryClient(); + const upload = useMutation(uploadSource, { + onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}), + }); const state = mustExist(useContext(StateContext)); const params = useStore(state, (s) => s.img2img); @@ -45,8 +45,6 @@ export function Img2Img(props: Img2ImgProps) { const setImg2Img = useStore(state, (s) => s.setImg2Img); // eslint-disable-next-line @typescript-eslint/unbound-method const setLoading = useStore(state, (s) => s.setLoading); - // eslint-disable-next-line @typescript-eslint/unbound-method - const pushHistory = useStore(state, (s) => s.pushHistory); const [source, setSource] = useState(); diff --git a/gui/src/components/Inpaint.tsx b/gui/src/components/Inpaint.tsx index 69137436..60580470 100644 --- a/gui/src/components/Inpaint.tsx +++ b/gui/src/components/Inpaint.tsx @@ -2,7 +2,7 @@ import { doesExist, mustExist } from '@apextoaster/js-utils'; import { FormatColorFill, Gradient } from '@mui/icons-material'; import { Box, Button, Stack } from '@mui/material'; import * as React from 'react'; -import { useMutation } from 'react-query'; +import { useMutation, useQueryClient } from 'react-query'; import { useStore } from 'zustand'; import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js'; @@ -69,7 +69,6 @@ export function Inpaint(props: InpaintProps) { async function uploadSource() { const canvas = mustExist(canvasRef.current); - setLoading(true); return new Promise((res, rej) => { canvas.toBlob((blob) => { client.inpaint({ @@ -79,8 +78,7 @@ export function Inpaint(props: InpaintProps) { mask: mustExist(blob), source: mustExist(source), }).then((output) => { - pushHistory(output); - setLoading(false); + setLoading(output); res(); }).catch((err) => rej(err)); }); @@ -146,7 +144,10 @@ export function Inpaint(props: InpaintProps) { // eslint-disable-next-line @typescript-eslint/unbound-method const pushHistory = useStore(state, (s) => s.pushHistory); - const upload = useMutation(uploadSource); + const query = useQueryClient(); + const upload = useMutation(uploadSource, { + onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}), + }); // eslint-disable-next-line no-null/no-null const canvasRef = useRef(null); diff --git a/gui/src/components/LoadingCard.tsx b/gui/src/components/LoadingCard.tsx index 8c3b49ae..5cc21114 100644 --- a/gui/src/components/LoadingCard.tsx +++ b/gui/src/components/LoadingCard.tsx @@ -1,15 +1,37 @@ +import { mustExist } from '@apextoaster/js-utils'; import { Card, CardContent, CircularProgress } from '@mui/material'; import * as React from 'react'; +import { useContext } from 'react'; +import { useQuery } from 'react-query'; +import { useStore } from 'zustand'; + +import { ApiResponse } from '../api/client.js'; +import { POLL_TIME } from '../config.js'; +import { ClientContext, StateContext } from '../state.js'; export interface LoadingCardProps { - height: number; - width: number; + loading: ApiResponse; } export function LoadingCard(props: LoadingCardProps) { - return - -
+ const client = mustExist(React.useContext(ClientContext)); + + // eslint-disable-next-line @typescript-eslint/unbound-method + const pushHistory = useStore(mustExist(useContext(StateContext)), (state) => state.pushHistory); + + const ready = useQuery('ready', () => client.ready(props.loading), { + refetchInterval: POLL_TIME, + }); + + React.useEffect(() => { + if (ready.status === 'success' && ready.data.ready) { + pushHistory(props.loading); + } + }, [ready.status, ready.data?.ready]); + + return + +
diff --git a/gui/src/components/Txt2Img.tsx b/gui/src/components/Txt2Img.tsx index 1a6b7e67..69c1ca6c 100644 --- a/gui/src/components/Txt2Img.tsx +++ b/gui/src/components/Txt2Img.tsx @@ -1,7 +1,7 @@ import { mustExist } from '@apextoaster/js-utils'; import { Box, Button, Stack } from '@mui/material'; import * as React from 'react'; -import { useMutation } from 'react-query'; +import { useMutation, useQueryClient } from 'react-query'; import { useStore } from 'zustand'; import { ConfigParams } from '../config.js'; @@ -22,20 +22,20 @@ export function Txt2Img(props: Txt2ImgProps) { const { config, model, platform } = props; async function generateImage() { - setLoading(true); - const output = await client.txt2img({ ...params, model, platform, }); - pushHistory(output); - setLoading(false); + setLoading(output); } const client = mustExist(useContext(ClientContext)); - const generate = useMutation(generateImage); + const query = useQueryClient(); + const generate = useMutation(generateImage, { + onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}), + }); const state = mustExist(useContext(StateContext)); const params = useStore(state, (s) => s.txt2img); diff --git a/gui/src/config.ts b/gui/src/config.ts index a6d11f7f..7b7811af 100644 --- a/gui/src/config.ts +++ b/gui/src/config.ts @@ -43,7 +43,8 @@ export const DEFAULT_BRUSH = { size: 8, }; export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png'; -export const STALE_TIME = 3_000; +export const STALE_TIME = 300_000; // 5 minutes +export const POLL_TIME = 5_000; // 5 seconds export async function loadConfig(): Promise { const configPath = new URL('./config.json', window.origin); diff --git a/gui/src/main.tsx b/gui/src/main.tsx index 5a55b88a..e503e6a1 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -41,12 +41,8 @@ export async function main() { ...createDefaultSlice(...slice), }), { name: 'onnx-web', - partialize: (oldState) => ({ - ...oldState, - loading: false, - }), storage: createJSONStorage(() => localStorage), - version: 2, + version: 3, })); // prep react-query client diff --git a/gui/src/state.ts b/gui/src/state.ts index 8a744c73..d932eeab 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -39,12 +39,12 @@ interface InpaintSlice { interface HistorySlice { history: Array; limit: number; - loading: boolean; + loading: Maybe; pushHistory(image: ApiResponse): void; removeHistory(image: ApiResponse): void; setLimit(limit: number): void; - setLoading(loading: boolean): void; + setLoading(image: Maybe): void; } interface DefaultSlice { @@ -130,7 +130,8 @@ export function createStateSlices(base: ConfigParams) { const createHistorySlice: StateCreator = (set) => ({ history: [], limit: 4, - loading: false, + // eslint-disable-next-line no-null/no-null + loading: null, pushHistory(image) { set((prev) => ({ ...prev, @@ -138,6 +139,8 @@ export function createStateSlices(base: ConfigParams) { image, ...prev.history, ], + // eslint-disable-next-line no-null/no-null + loading: null, })); }, removeHistory(image) {