From 89790645cbaf74b7cad2d7e7fd693665db13caa4 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Mar 2023 18:22:41 -0500 Subject: [PATCH] feat(gui): add retry function to error card --- gui/src/client/api.ts | 157 +++++++++++++++++++++----- gui/src/client/local.ts | 3 + gui/src/components/ImageHistory.tsx | 2 +- gui/src/components/card/RetryCard.tsx | 37 ++++-- gui/src/components/tab/Blend.tsx | 5 +- gui/src/components/tab/Img2Img.tsx | 5 +- gui/src/components/tab/Inpaint.tsx | 8 +- gui/src/components/tab/Txt2Img.tsx | 4 +- gui/src/components/tab/Upscale.tsx | 5 +- gui/src/state.ts | 7 +- gui/src/strings/de.ts | 1 + gui/src/strings/en.ts | 1 + gui/src/strings/es.ts | 1 + gui/src/strings/fr.ts | 1 + 14 files changed, 182 insertions(+), 55 deletions(-) diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 2f01b241..087f54b6 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -1,5 +1,5 @@ /* eslint-disable max-lines */ -import { doesExist } from '@apextoaster/js-utils'; +import { doesExist, InvalidArgumentError } from '@apextoaster/js-utils'; import { ServerParams } from '../config.js'; import { range } from '../utils.js'; @@ -191,6 +191,43 @@ export interface ModelsResponse { upscaling: Array; } +export type RetryParams = { + type: 'txt2img'; + model: ModelParams; + params: Txt2ImgParams; + upscale?: UpscaleParams; +} | { + type: 'img2img'; + model: ModelParams; + params: Img2ImgParams; + upscale?: UpscaleParams; +} | { + type: 'inpaint'; + model: ModelParams; + params: InpaintParams; + upscale?: UpscaleParams; +} | { + type: 'outpaint'; + model: ModelParams; + params: OutpaintParams; + upscale?: UpscaleParams; +} | { + type: 'upscale'; + model: ModelParams; + params: UpscaleReqParams; + upscale?: UpscaleParams; +} | { + type: 'blend'; + model: ModelParams; + params: BlendParams; + upscale?: UpscaleParams; +}; + +export interface ImageResponseWithRetry { + image: ImageResponse; + retry: RetryParams; +} + export interface ApiClient { /** * List the available filter masks for inpaint. @@ -232,32 +269,32 @@ export interface ApiClient { /** * Start a txt2img pipeline. */ - txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise; + txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise; /** * Start an im2img pipeline. */ - img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise; + img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise; /** * Start an inpaint pipeline. */ - inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise; + inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise; /** * Start an outpaint pipeline. */ - outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise; + outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise; /** * Start an upscale pipeline. */ - upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise; + upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise; /** * Start a blending pipeline. */ - blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise; + blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise; /** * Check whether some pipeline's output is ready yet. @@ -265,6 +302,8 @@ export interface ApiClient { ready(key: string): Promise; cancel(key: string): Promise; + + retry(params: RetryParams): Promise; } /** @@ -363,7 +402,7 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { * Make an API client using the given API root and fetch client. */ export function makeClient(root: string, f = fetch): ApiClient { - function throttleRequest(url: URL, options: RequestInit): Promise { + function parseRequest(url: URL, options: RequestInit): Promise { return f(url, options).then((res) => parseApiResponse(root, res)); } @@ -407,7 +446,7 @@ export function makeClient(root: string, f = fetch): ApiClient { translation: Record; }>; }, - async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise { + async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise { const url = makeImageURL(root, 'img2img', params); appendModelToURL(url, model); @@ -420,13 +459,21 @@ export function makeClient(root: string, f = fetch): ApiClient { const body = new FormData(); body.append('source', params.source, 'source'); - // eslint-disable-next-line no-return-await - return await throttleRequest(url, { + const image = await parseRequest(url, { body, method: 'POST', }); + return { + image, + retry: { + type: 'img2img', + model, + params, + upscale, + }, + }; }, - async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise { + async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise { const url = makeImageURL(root, 'txt2img', params); appendModelToURL(url, model); @@ -442,12 +489,20 @@ export function makeClient(root: string, f = fetch): ApiClient { appendUpscaleToURL(url, upscale); } - // eslint-disable-next-line no-return-await - return await throttleRequest(url, { + const image = await parseRequest(url, { method: 'POST', }); + return { + image, + retry: { + type: 'txt2img', + model, + params, + upscale, + }, + }; }, - async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams) { + async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise { const url = makeImageURL(root, 'inpaint', params); appendModelToURL(url, model); @@ -464,13 +519,21 @@ export function makeClient(root: string, f = fetch): ApiClient { body.append('mask', params.mask, 'mask'); body.append('source', params.source, 'source'); - // eslint-disable-next-line no-return-await - return await throttleRequest(url, { + const image = await parseRequest(url, { body, method: 'POST', }); + return { + image, + retry: { + type: 'inpaint', + model, + params, + upscale, + }, + }; }, - async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams) { + async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise { const url = makeImageURL(root, 'inpaint', params); appendModelToURL(url, model); @@ -504,13 +567,21 @@ export function makeClient(root: string, f = fetch): ApiClient { body.append('mask', params.mask, 'mask'); body.append('source', params.source, 'source'); - // eslint-disable-next-line no-return-await - return await throttleRequest(url, { + const image = await parseRequest(url, { body, method: 'POST', }); + return { + image, + retry: { + type: 'outpaint', + model, + params, + upscale, + }, + }; }, - async upscale(model: ModelParams, params: UpscaleReqParams, upscale: UpscaleParams): Promise { + async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise { const url = makeApiUrl(root, 'upscale'); appendModelToURL(url, model); @@ -527,13 +598,21 @@ export function makeClient(root: string, f = fetch): ApiClient { const body = new FormData(); body.append('source', params.source, 'source'); - // eslint-disable-next-line no-return-await - return await throttleRequest(url, { + const image = await parseRequest(url, { body, method: 'POST', }); + return { + image, + retry: { + type: 'upscale', + model, + params, + upscale, + }, + }; }, - async blend(model: ModelParams, params: BlendParams, upscale: UpscaleParams): Promise { + async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise { const url = makeApiUrl(root, 'blend'); appendModelToURL(url, model); @@ -549,11 +628,19 @@ export function makeClient(root: string, f = fetch): ApiClient { body.append(name, params.sources[i], name); } - // eslint-disable-next-line no-return-await - return await throttleRequest(url, { + const image = await parseRequest(url, { body, method: 'POST', }); + return { + image, + retry: { + type: 'blend', + model, + params, + upscale, + } + }; }, async ready(key: string): Promise { const path = makeApiUrl(root, 'ready'); @@ -571,6 +658,24 @@ export function makeClient(root: string, f = fetch): ApiClient { }); return res.status === STATUS_SUCCESS; }, + async retry(retry: RetryParams): Promise { + switch (retry.type) { + case 'blend': + return this.blend(retry.model, retry.params, retry.upscale); + case 'img2img': + return this.img2img(retry.model, retry.params, retry.upscale); + case 'inpaint': + return this.inpaint(retry.model, retry.params, retry.upscale); + case 'outpaint': + return this.outpaint(retry.model, retry.params, retry.upscale); + case 'txt2img': + return this.txt2img(retry.model, retry.params, retry.upscale); + case 'upscale': + return this.upscale(retry.model, retry.params, retry.upscale); + default: + throw new InvalidArgumentError('unknown request type'); + } + } }; } diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts index bd560713..8dc931af 100644 --- a/gui/src/client/local.ts +++ b/gui/src/client/local.ts @@ -44,6 +44,9 @@ export const LOCAL_CLIENT = { async cancel(key) { throw new NoServerError(); }, + async retry(params) { + throw new NoServerError(); + }, async models() { throw new NoServerError(); }, diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx index fad4d2b3..98b111fd 100644 --- a/gui/src/components/ImageHistory.tsx +++ b/gui/src/components/ImageHistory.tsx @@ -30,7 +30,7 @@ export function ImageHistory() { if (doesExist(item.ready) && item.ready.ready) { if (item.ready.cancelled || item.ready.failed) { - children.push([key, ]); + children.push([key, ]); continue; } diff --git a/gui/src/components/card/RetryCard.tsx b/gui/src/components/card/RetryCard.tsx index 501ce968..d47f4111 100644 --- a/gui/src/components/card/RetryCard.tsx +++ b/gui/src/components/card/RetryCard.tsx @@ -1,5 +1,6 @@ import { mustExist } from '@apextoaster/js-utils'; -import { Box, Button, Card, CardContent, Typography } from '@mui/material'; +import { Delete, Replay } from '@mui/icons-material'; +import { Box, Card, CardContent, IconButton, Tooltip, Typography } from '@mui/material'; import { Stack } from '@mui/system'; import * as React from 'react'; import { useContext } from 'react'; @@ -7,31 +8,35 @@ import { useTranslation } from 'react-i18next'; import { useMutation } from 'react-query'; import { useStore } from 'zustand'; -import { ImageResponse, ReadyResponse } from '../../client/api.js'; +import { ImageResponse, ReadyResponse, RetryParams } from '../../client/api.js'; import { ClientContext, ConfigContext, StateContext } from '../../state.js'; export interface ErrorCardProps { image: ImageResponse; ready: ReadyResponse; + retry: RetryParams; } export function ErrorCard(props: ErrorCardProps) { - const { image, ready } = props; + const { image, ready, retry: retryParams } = props; const client = mustExist(React.useContext(ClientContext)); const { params } = mustExist(useContext(ConfigContext)); const state = mustExist(useContext(StateContext)); // eslint-disable-next-line @typescript-eslint/unbound-method + const pushHistory = useStore(state, (s) => s.pushHistory); + // eslint-disable-next-line @typescript-eslint/unbound-method const removeHistory = useStore(state, (s) => s.removeHistory); const { t } = useTranslation(); - // TODO: actually retry - const retry = useMutation(() => { - // eslint-disable-next-line no-console - console.log('retry', image); - return Promise.resolve(true); - }); + async function retryImage() { + removeHistory(image); + const { image: nextImage, retry: nextRetry } = await client.retry(retryParams); + pushHistory(nextImage, nextRetry); + } + + const retry = useMutation(retryImage); return @@ -50,8 +55,18 @@ export function ErrorCard(props: ErrorCardProps) { current: ready.progress, total: image.params.steps, })} - - + + + retry.mutate()}> + + + + + removeHistory(image)}> + + + + diff --git a/gui/src/components/tab/Blend.tsx b/gui/src/components/tab/Blend.tsx index cbeb7d85..943a1594 100644 --- a/gui/src/components/tab/Blend.tsx +++ b/gui/src/components/tab/Blend.tsx @@ -16,14 +16,13 @@ import { MaskCanvas } from '../input/MaskCanvas.js'; export function Blend() { async function uploadSource() { const { model, blend, upscale } = state.getState(); - - const output = await client.blend(model, { + const { image, retry } = await client.blend(model, { ...blend, mask: mustExist(blend.mask), sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist }, upscale); - pushHistory(output); + pushHistory(image, retry); } const client = mustExist(useContext(ClientContext)); diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx index 1d89973e..5fccc22f 100644 --- a/gui/src/components/tab/Img2Img.tsx +++ b/gui/src/components/tab/Img2Img.tsx @@ -18,13 +18,12 @@ export function Img2Img() { async function uploadSource() { const { model, img2img, upscale } = state.getState(); - - const output = await client.img2img(model, { + const { image, retry } = await client.img2img(model, { ...img2img, source: mustExist(img2img.source), // TODO: show an error if this doesn't exist }, upscale); - pushHistory(output); + pushHistory(image, retry); } const client = mustExist(useContext(ClientContext)); diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx index 20c935ab..17c35229 100644 --- a/gui/src/components/tab/Inpaint.tsx +++ b/gui/src/components/tab/Inpaint.tsx @@ -32,22 +32,22 @@ export function Inpaint() { const { model, inpaint, outpaint, upscale } = state.getState(); if (outpaint.enabled) { - const output = await client.outpaint(model, { + const { image, retry } = await client.outpaint(model, { ...inpaint, ...outpaint, mask: mustExist(mask), source: mustExist(source), }, upscale); - pushHistory(output); + pushHistory(image, retry); } else { - const output = await client.inpaint(model, { + const { image, retry } = await client.inpaint(model, { ...inpaint, mask: mustExist(mask), source: mustExist(source), }, upscale); - pushHistory(output); + pushHistory(image, retry); } } diff --git a/gui/src/components/tab/Txt2Img.tsx b/gui/src/components/tab/Txt2Img.tsx index 4990f222..9794edc7 100644 --- a/gui/src/components/tab/Txt2Img.tsx +++ b/gui/src/components/tab/Txt2Img.tsx @@ -16,9 +16,9 @@ export function Txt2Img() { async function generateImage() { const { model, txt2img, upscale } = state.getState(); - const output = await client.txt2img(model, txt2img, upscale); + const { image, retry } = await client.txt2img(model, txt2img, upscale); - pushHistory(output); + pushHistory(image, retry); } const client = mustExist(useContext(ClientContext)); diff --git a/gui/src/components/tab/Upscale.tsx b/gui/src/components/tab/Upscale.tsx index 9faef6b6..ede4c32f 100644 --- a/gui/src/components/tab/Upscale.tsx +++ b/gui/src/components/tab/Upscale.tsx @@ -15,13 +15,12 @@ import { PromptInput } from '../input/PromptInput.js'; export function Upscale() { async function uploadSource() { const { model, upscale } = state.getState(); - - const output = await client.upscale(model, { + const { image, retry } = await client.upscale(model, { ...params, source: mustExist(params.source), // TODO: show an error if this doesn't exist }, upscale); - pushHistory(output); + pushHistory(image, retry); } const client = mustExist(useContext(ClientContext)); diff --git a/gui/src/state.ts b/gui/src/state.ts index 2a73670b..39f9123b 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -16,6 +16,7 @@ import { ModelParams, OutpaintPixels, ReadyResponse, + RetryParams, Txt2ImgParams, UpscaleParams, UpscaleReqParams, @@ -30,6 +31,7 @@ type TabState = ConfigFiles> & ConfigState; + retry: RetryParams; } interface BrushSlice { @@ -48,7 +50,7 @@ interface HistorySlice { history: Array; limit: number; - pushHistory(image: ImageResponse): void; + pushHistory(image: ImageResponse, retry: RetryParams): void; removeHistory(image: ImageResponse): void; setLimit(limit: number): void; setReady(image: ImageResponse, ready: ReadyResponse): void; @@ -301,7 +303,7 @@ export function createStateSlices(server: ServerParams) { const createHistorySlice: Slice = (set) => ({ history: [], limit: DEFAULT_HISTORY.limit, - pushHistory(image) { + pushHistory(image, retry) { set((prev) => ({ ...prev, history: [ @@ -313,6 +315,7 @@ export function createStateSlices(server: ServerParams) { progress: 0, ready: false, }, + retry, }, ...prev.history, ], diff --git a/gui/src/strings/de.ts b/gui/src/strings/de.ts index a1e68bbd..b24330c5 100644 --- a/gui/src/strings/de.ts +++ b/gui/src/strings/de.ts @@ -154,6 +154,7 @@ export const I18N_STRINGS_DE = { delete: 'Löschen', next: 'Nächste', previous: 'Vorherige', + retry: '', save: 'Speichern', }, upscaleOrder: { diff --git a/gui/src/strings/en.ts b/gui/src/strings/en.ts index ab1751a6..d79efb7a 100644 --- a/gui/src/strings/en.ts +++ b/gui/src/strings/en.ts @@ -217,6 +217,7 @@ export const I18N_STRINGS_EN = { delete: 'Delete', next: 'Next', previous: 'Previous', + retry: 'Retry', save: 'Save', }, upscaleOrder: { diff --git a/gui/src/strings/es.ts b/gui/src/strings/es.ts index 23109260..e0b6a47e 100644 --- a/gui/src/strings/es.ts +++ b/gui/src/strings/es.ts @@ -154,6 +154,7 @@ export const I18N_STRINGS_ES = { delete: 'Borrar', next: 'Próximo', previous: 'Anterior', + retry: '', save: 'Ahorrar', }, upscaleOrder: { diff --git a/gui/src/strings/fr.ts b/gui/src/strings/fr.ts index 652386bf..c4455be6 100644 --- a/gui/src/strings/fr.ts +++ b/gui/src/strings/fr.ts @@ -154,6 +154,7 @@ export const I18N_STRINGS_FR = { delete: '', next: '', previous: '', + retry: '', save: '', }, upscaleOrder: {