1
0
Fork 0

feat(gui): add retry function to error card

This commit is contained in:
Sean Sube 2023-03-18 18:22:41 -05:00
parent 6226778cfb
commit 89790645cb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
14 changed files with 182 additions and 55 deletions

View File

@ -1,5 +1,5 @@
/* eslint-disable max-lines */
import { doesExist } from '@apextoaster/js-utils';
import { doesExist, InvalidArgumentError } from '@apextoaster/js-utils';
import { ServerParams } from '../config.js';
import { range } from '../utils.js';
@ -191,6 +191,43 @@ export interface ModelsResponse {
upscaling: Array<string>;
}
export type RetryParams = {
type: 'txt2img';
model: ModelParams;
params: Txt2ImgParams;
upscale?: UpscaleParams;
} | {
type: 'img2img';
model: ModelParams;
params: Img2ImgParams;
upscale?: UpscaleParams;
} | {
type: 'inpaint';
model: ModelParams;
params: InpaintParams;
upscale?: UpscaleParams;
} | {
type: 'outpaint';
model: ModelParams;
params: OutpaintParams;
upscale?: UpscaleParams;
} | {
type: 'upscale';
model: ModelParams;
params: UpscaleReqParams;
upscale?: UpscaleParams;
} | {
type: 'blend';
model: ModelParams;
params: BlendParams;
upscale?: UpscaleParams;
};
export interface ImageResponseWithRetry {
image: ImageResponse;
retry: RetryParams;
}
export interface ApiClient {
/**
* List the available filter masks for inpaint.
@ -232,32 +269,32 @@ export interface ApiClient {
/**
* Start a txt2img pipeline.
*/
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
/**
* Start an im2img pipeline.
*/
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse>;
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
/**
* Start an inpaint pipeline.
*/
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponse>;
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
/**
* Start an outpaint pipeline.
*/
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponse>;
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
/**
* Start an upscale pipeline.
*/
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponse>;
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
/**
* Start a blending pipeline.
*/
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponse>;
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
/**
* Check whether some pipeline's output is ready yet.
@ -265,6 +302,8 @@ export interface ApiClient {
ready(key: string): Promise<ReadyResponse>;
cancel(key: string): Promise<boolean>;
retry(params: RetryParams): Promise<ImageResponseWithRetry>;
}
/**
@ -363,7 +402,7 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
* Make an API client using the given API root and fetch client.
*/
export function makeClient(root: string, f = fetch): ApiClient {
function throttleRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
function parseRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
return f(url, options).then((res) => parseApiResponse(root, res));
}
@ -407,7 +446,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
translation: Record<string, string>;
}>;
},
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);
@ -420,13 +459,21 @@ export function makeClient(root: string, f = fetch): ApiClient {
const body = new FormData();
body.append('source', params.source, 'source');
// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'img2img',
model,
params,
upscale,
},
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
};
},
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'txt2img', params);
appendModelToURL(url, model);
@ -442,12 +489,20 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}
// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
method: 'POST',
});
return {
image,
retry: {
type: 'txt2img',
model,
params,
upscale,
},
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams) {
};
},
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
@ -464,13 +519,21 @@ export function makeClient(root: string, f = fetch): ApiClient {
body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source');
// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'inpaint',
model,
params,
upscale,
},
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams) {
};
},
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
@ -504,13 +567,21 @@ export function makeClient(root: string, f = fetch): ApiClient {
body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source');
// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'outpaint',
model,
params,
upscale,
},
async upscale(model: ModelParams, params: UpscaleReqParams, upscale: UpscaleParams): Promise<ImageResponse> {
};
},
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeApiUrl(root, 'upscale');
appendModelToURL(url, model);
@ -527,13 +598,21 @@ export function makeClient(root: string, f = fetch): ApiClient {
const body = new FormData();
body.append('source', params.source, 'source');
// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'upscale',
model,
params,
upscale,
},
async blend(model: ModelParams, params: BlendParams, upscale: UpscaleParams): Promise<ImageResponse> {
};
},
async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
const url = makeApiUrl(root, 'blend');
appendModelToURL(url, model);
@ -549,11 +628,19 @@ export function makeClient(root: string, f = fetch): ApiClient {
body.append(name, params.sources[i], name);
}
// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
const image = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
retry: {
type: 'blend',
model,
params,
upscale,
}
};
},
async ready(key: string): Promise<ReadyResponse> {
const path = makeApiUrl(root, 'ready');
@ -571,6 +658,24 @@ export function makeClient(root: string, f = fetch): ApiClient {
});
return res.status === STATUS_SUCCESS;
},
async retry(retry: RetryParams): Promise<ImageResponseWithRetry> {
switch (retry.type) {
case 'blend':
return this.blend(retry.model, retry.params, retry.upscale);
case 'img2img':
return this.img2img(retry.model, retry.params, retry.upscale);
case 'inpaint':
return this.inpaint(retry.model, retry.params, retry.upscale);
case 'outpaint':
return this.outpaint(retry.model, retry.params, retry.upscale);
case 'txt2img':
return this.txt2img(retry.model, retry.params, retry.upscale);
case 'upscale':
return this.upscale(retry.model, retry.params, retry.upscale);
default:
throw new InvalidArgumentError('unknown request type');
}
}
};
}

View File

@ -44,6 +44,9 @@ export const LOCAL_CLIENT = {
async cancel(key) {
throw new NoServerError();
},
async retry(params) {
throw new NoServerError();
},
async models() {
throw new NoServerError();
},

View File

@ -30,7 +30,7 @@ export function ImageHistory() {
if (doesExist(item.ready) && item.ready.ready) {
if (item.ready.cancelled || item.ready.failed) {
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} ready={item.ready} />]);
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} ready={item.ready} retry={item.retry} />]);
continue;
}

View File

@ -1,5 +1,6 @@
import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Card, CardContent, Typography } from '@mui/material';
import { Delete, Replay } from '@mui/icons-material';
import { Box, Card, CardContent, IconButton, Tooltip, Typography } from '@mui/material';
import { Stack } from '@mui/system';
import * as React from 'react';
import { useContext } from 'react';
@ -7,31 +8,35 @@ import { useTranslation } from 'react-i18next';
import { useMutation } from 'react-query';
import { useStore } from 'zustand';
import { ImageResponse, ReadyResponse } from '../../client/api.js';
import { ImageResponse, ReadyResponse, RetryParams } from '../../client/api.js';
import { ClientContext, ConfigContext, StateContext } from '../../state.js';
export interface ErrorCardProps {
image: ImageResponse;
ready: ReadyResponse;
retry: RetryParams;
}
export function ErrorCard(props: ErrorCardProps) {
const { image, ready } = props;
const { image, ready, retry: retryParams } = props;
const client = mustExist(React.useContext(ClientContext));
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeHistory = useStore(state, (s) => s.removeHistory);
const { t } = useTranslation();
// TODO: actually retry
const retry = useMutation(() => {
// eslint-disable-next-line no-console
console.log('retry', image);
return Promise.resolve(true);
});
async function retryImage() {
removeHistory(image);
const { image: nextImage, retry: nextRetry } = await client.retry(retryParams);
pushHistory(nextImage, nextRetry);
}
const retry = useMutation(retryImage);
return <Card sx={{ maxWidth: params.width.default }}>
<CardContent sx={{ height: params.height.default }}>
@ -50,8 +55,18 @@ export function ErrorCard(props: ErrorCardProps) {
current: ready.progress,
total: image.params.steps,
})}</Typography>
<Button onClick={() => retry.mutate()}>{t('loading.retry')}</Button>
<Button onClick={() => removeHistory(image)}>{t('loading.remove')}</Button>
<Stack direction='row' spacing={2}>
<Tooltip title={t('tooltip.retry')}>
<IconButton onClick={() => retry.mutate()}>
<Replay />
</IconButton>
</Tooltip>
<Tooltip title={t('tooltip.delete')}>
<IconButton onClick={() => removeHistory(image)}>
<Delete />
</IconButton>
</Tooltip>
</Stack>
</Stack>
</Box>
</CardContent>

View File

@ -16,14 +16,13 @@ import { MaskCanvas } from '../input/MaskCanvas.js';
export function Blend() {
async function uploadSource() {
const { model, blend, upscale } = state.getState();
const output = await client.blend(model, {
const { image, retry } = await client.blend(model, {
...blend,
mask: mustExist(blend.mask),
sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
}, upscale);
pushHistory(output);
pushHistory(image, retry);
}
const client = mustExist(useContext(ClientContext));

View File

@ -18,13 +18,12 @@ export function Img2Img() {
async function uploadSource() {
const { model, img2img, upscale } = state.getState();
const output = await client.img2img(model, {
const { image, retry } = await client.img2img(model, {
...img2img,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, upscale);
pushHistory(output);
pushHistory(image, retry);
}
const client = mustExist(useContext(ClientContext));

View File

@ -32,22 +32,22 @@ export function Inpaint() {
const { model, inpaint, outpaint, upscale } = state.getState();
if (outpaint.enabled) {
const output = await client.outpaint(model, {
const { image, retry } = await client.outpaint(model, {
...inpaint,
...outpaint,
mask: mustExist(mask),
source: mustExist(source),
}, upscale);
pushHistory(output);
pushHistory(image, retry);
} else {
const output = await client.inpaint(model, {
const { image, retry } = await client.inpaint(model, {
...inpaint,
mask: mustExist(mask),
source: mustExist(source),
}, upscale);
pushHistory(output);
pushHistory(image, retry);
}
}

View File

@ -16,9 +16,9 @@ export function Txt2Img() {
async function generateImage() {
const { model, txt2img, upscale } = state.getState();
const output = await client.txt2img(model, txt2img, upscale);
const { image, retry } = await client.txt2img(model, txt2img, upscale);
pushHistory(output);
pushHistory(image, retry);
}
const client = mustExist(useContext(ClientContext));

View File

@ -15,13 +15,12 @@ import { PromptInput } from '../input/PromptInput.js';
export function Upscale() {
async function uploadSource() {
const { model, upscale } = state.getState();
const output = await client.upscale(model, {
const { image, retry } = await client.upscale(model, {
...params,
source: mustExist(params.source), // TODO: show an error if this doesn't exist
}, upscale);
pushHistory(output);
pushHistory(image, retry);
}
const client = mustExist(useContext(ClientContext));

View File

@ -16,6 +16,7 @@ import {
ModelParams,
OutpaintPixels,
ReadyResponse,
RetryParams,
Txt2ImgParams,
UpscaleParams,
UpscaleReqParams,
@ -30,6 +31,7 @@ type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Requir
interface HistoryItem {
image: ImageResponse;
ready: Maybe<ReadyResponse>;
retry: RetryParams;
}
interface BrushSlice {
@ -48,7 +50,7 @@ interface HistorySlice {
history: Array<HistoryItem>;
limit: number;
pushHistory(image: ImageResponse): void;
pushHistory(image: ImageResponse, retry: RetryParams): void;
removeHistory(image: ImageResponse): void;
setLimit(limit: number): void;
setReady(image: ImageResponse, ready: ReadyResponse): void;
@ -301,7 +303,7 @@ export function createStateSlices(server: ServerParams) {
const createHistorySlice: Slice<HistorySlice> = (set) => ({
history: [],
limit: DEFAULT_HISTORY.limit,
pushHistory(image) {
pushHistory(image, retry) {
set((prev) => ({
...prev,
history: [
@ -313,6 +315,7 @@ export function createStateSlices(server: ServerParams) {
progress: 0,
ready: false,
},
retry,
},
...prev.history,
],

View File

@ -154,6 +154,7 @@ export const I18N_STRINGS_DE = {
delete: 'Löschen',
next: 'Nächste',
previous: 'Vorherige',
retry: '',
save: 'Speichern',
},
upscaleOrder: {

View File

@ -217,6 +217,7 @@ export const I18N_STRINGS_EN = {
delete: 'Delete',
next: 'Next',
previous: 'Previous',
retry: 'Retry',
save: 'Save',
},
upscaleOrder: {

View File

@ -154,6 +154,7 @@ export const I18N_STRINGS_ES = {
delete: 'Borrar',
next: 'Próximo',
previous: 'Anterior',
retry: '',
save: 'Ahorrar',
},
upscaleOrder: {

View File

@ -154,6 +154,7 @@ export const I18N_STRINGS_FR = {
delete: '',
next: '',
previous: '',
retry: '',
save: '',
},
upscaleOrder: {