1
0
Fork 0

feat(gui): add blend tab and copy button

This commit is contained in:
Sean Sube 2023-02-12 16:52:50 -06:00
parent d6201c9d32
commit 4abbb00fd0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 175 additions and 12 deletions

View File

@ -129,6 +129,11 @@ export interface UpscaleReqParams {
source: Blob; source: Blob;
} }
export interface BlendParams {
sources: Array<Blob>;
mask: Blob;
}
/** /**
* General response for most image requests. * General response for most image requests.
*/ */
@ -217,6 +222,11 @@ export interface ApiClient {
*/ */
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponse>; upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponse>;
/**
* Start a blending pipeline.
*/
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponse>;
/** /**
* Check whether some pipeline's output is ready yet. * Check whether some pipeline's output is ready yet.
*/ */
@ -471,6 +481,9 @@ export function makeClient(root: string, f = fetch): ApiClient {
method: 'POST', method: 'POST',
}); });
}, },
async blend(model: ModelParams, params: BlendParams, upscale: UpscaleParams): Promise<ImageResponse> {
throw new Error('TODO');
},
async ready(params: ImageResponse): Promise<ReadyResponse> { async ready(params: ImageResponse): Promise<ReadyResponse> {
const path = makeApiUrl(root, 'ready'); const path = makeApiUrl(root, 'ready');
path.searchParams.append('output', params.output.key); path.searchParams.append('output', params.output.key);

View File

@ -1,5 +1,5 @@
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils'; import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Brush, ContentCopy, Delete, Download } from '@mui/icons-material'; import { Blender, Brush, ContentCopy, CropFree, Delete, Download, ZoomOutMap } from '@mui/icons-material';
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Paper, Tooltip } from '@mui/material'; import { Box, Card, CardContent, CardMedia, Grid, IconButton, Paper, Tooltip } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useContext } from 'react'; import { useContext } from 'react';
@ -33,6 +33,10 @@ export function ImageCard(props: ImageCardProps) {
const setImg2Img = useStore(state, (s) => s.setImg2Img); const setImg2Img = useStore(state, (s) => s.setImg2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method // eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint); const setInpaint = useStore(state, (s) => s.setInpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setUpscaleTab);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBlend = useStore(state, (s) => s.setBlend);
async function loadSource() { async function loadSource() {
const req = await fetch(output.url); const req = await fetch(output.url);
@ -55,6 +59,23 @@ export function ImageCard(props: ImageCardProps) {
setHash('inpaint'); setHash('inpaint');
} }
async function copySourceToUpscale() {
const blob = await loadSource();
setUpscale({
source: blob,
});
setHash('upscale');
}
async function copySourceToBlend() {
const blob = await loadSource();
// TODO: push instead
setBlend({
sources: [blob],
});
setHash('blend');
}
function deleteImage() { function deleteImage() {
if (doesExist(props.onDelete)) { if (doesExist(props.onDelete)) {
props.onDelete(value); props.onDelete(value);
@ -86,28 +107,42 @@ export function ImageCard(props: ImageCardProps) {
<GridItem xs={12}> <GridItem xs={12}>
<Box textAlign='left'>{params.prompt}</Box> <Box textAlign='left'>{params.prompt}</Box>
</GridItem> </GridItem>
<GridItem xs={3}> <GridItem xs={2}>
<Tooltip title='Save'> <Tooltip title='Save'>
<IconButton onClick={downloadImage}> <IconButton onClick={downloadImage}>
<Download /> <Download />
</IconButton> </IconButton>
</Tooltip> </Tooltip>
</GridItem> </GridItem>
<GridItem xs={3}> <GridItem xs={2}>
<Tooltip title='Img2img'> <Tooltip title='Img2img'>
<IconButton onClick={copySourceToImg2Img}> <IconButton onClick={copySourceToImg2Img}>
<ContentCopy /> <ContentCopy />
</IconButton> </IconButton>
</Tooltip> </Tooltip>
</GridItem> </GridItem>
<GridItem xs={3}> <GridItem xs={2}>
<Tooltip title='Inpaint'> <Tooltip title='Inpaint'>
<IconButton onClick={copySourceToInpaint}> <IconButton onClick={copySourceToInpaint}>
<Brush /> <Brush />
</IconButton> </IconButton>
</Tooltip> </Tooltip>
</GridItem> </GridItem>
<GridItem xs={3}> <GridItem xs={2}>
<Tooltip title='Upscale'>
<IconButton onClick={copySourceToUpscale}>
<ZoomOutMap />
</IconButton>
</Tooltip>
</GridItem>
<GridItem xs={2}>
<Tooltip title='Blend'>
<IconButton onClick={copySourceToBlend}>
<Blender />
</IconButton>
</Tooltip>
</GridItem>
<GridItem xs={2}>
<Tooltip title='Delete'> <Tooltip title='Delete'>
<IconButton onClick={deleteImage}> <IconButton onClick={deleteImage}>
<Delete /> <Delete />

View File

@ -6,6 +6,7 @@ import { useHash } from 'react-use/lib/useHash';
import { ModelControl } from './control/ModelControl.js'; import { ModelControl } from './control/ModelControl.js';
import { ImageHistory } from './ImageHistory.js'; import { ImageHistory } from './ImageHistory.js';
import { Blend } from './tab/Blend.js';
import { Img2Img } from './tab/Img2Img.js'; import { Img2Img } from './tab/Img2Img.js';
import { Inpaint } from './tab/Inpaint.js'; import { Inpaint } from './tab/Inpaint.js';
import { Settings } from './tab/Settings.js'; import { Settings } from './tab/Settings.js';
@ -13,6 +14,14 @@ import { Txt2Img } from './tab/Txt2Img.js';
import { Upscale } from './tab/Upscale.js'; import { Upscale } from './tab/Upscale.js';
const REMOVE_HASH = /^#?(.*)$/; const REMOVE_HASH = /^#?(.*)$/;
const TAB_LABELS = [
'txt2img',
'img2img',
'inpaint',
'upscale',
'blend',
'settings',
];
export function OnnxWeb() { export function OnnxWeb() {
const [hash, setHash] = useHash(); const [hash, setHash] = useHash();
@ -26,7 +35,7 @@ export function OnnxWeb() {
} }
} }
return 'txt2img'; return TAB_LABELS[0];
} }
return ( return (
@ -44,11 +53,7 @@ export function OnnxWeb() {
<TabList onChange={(_e, idx) => { <TabList onChange={(_e, idx) => {
setHash(idx); setHash(idx);
}}> }}>
<Tab label='txt2img' value='txt2img' /> {TAB_LABELS.map((name) => <Tab key={name} label={name} value={name} />)}
<Tab label='img2img' value='img2img' />
<Tab label='inpaint' value='inpaint' />
<Tab label='upscale' value='upscale' />
<Tab label='settings' value='settings' />
</TabList> </TabList>
</Box> </Box>
<TabPanel value='txt2img'> <TabPanel value='txt2img'>
@ -63,6 +68,9 @@ export function OnnxWeb() {
<TabPanel value='upscale'> <TabPanel value='upscale'>
<Upscale /> <Upscale />
</TabPanel> </TabPanel>
<TabPanel value='blend'>
<Blend />
</TabPanel>
<TabPanel value='settings'> <TabPanel value='settings'>
<Settings /> <Settings />
</TabPanel> </TabPanel>

View File

@ -0,0 +1,70 @@
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';
import { IMAGE_FILTER } from '../../config.js';
import { ClientContext, StateContext } from '../../state.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { MaskCanvas } from '../input/MaskCanvas.js';
export function Blend() {
async function uploadSource() {
const { model, blend, upscale } = state.getState();
const output = await client.blend(model, {
...blend,
mask: mustExist(blend.mask),
sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
}, upscale);
setLoading(output);
}
const client = mustExist(useContext(ClientContext));
const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }),
});
const state = mustExist(useContext(StateContext));
const blend = useStore(state, (s) => s.blend);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBlend = useStore(state, (s) => s.setBlend);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.pushLoading);
const sources = mustDefault(blend.sources, []);
return <Box>
<Stack spacing={2}>
<ImageInput
filter={IMAGE_FILTER}
image={sources[0]}
hideSelection={true}
label='Source'
onChange={(file) => {
setBlend({
sources: [file],
});
}}
/>
<MaskCanvas
source={sources[0]}
mask={blend.mask}
onSave={() => {
// TODO
}}
/>
<UpscaleControl />
<Button
disabled={sources.length === 0}
variant='contained'
onClick={() => upload.mutate()}
>Generate</Button>
</Stack>
</Box>;
}

View File

@ -25,7 +25,7 @@ export type KeyFilter<T extends object, TValid = number | string> = {
* Keep fields with a file-like value, but make them optional. * Keep fields with a file-like value, but make them optional.
*/ */
export type ConfigFiles<T extends object> = { export type ConfigFiles<T extends object> = {
[K in KeyFilter<T, Blob | File>]: Maybe<T[K]>; [K in KeyFilter<T, Blob | File | Array<Blob | File>>]: Maybe<T[K]>;
}; };
/** /**

View File

@ -47,6 +47,7 @@ export async function main() {
createOutpaintSlice, createOutpaintSlice,
createTxt2ImgSlice, createTxt2ImgSlice,
createUpscaleSlice, createUpscaleSlice,
createBlendSlice,
createResetSlice, createResetSlice,
} = createStateSlices(params); } = createStateSlices(params);
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((...slice) => ({ const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((...slice) => ({
@ -59,6 +60,7 @@ export async function main() {
...createTxt2ImgSlice(...slice), ...createTxt2ImgSlice(...slice),
...createOutpaintSlice(...slice), ...createOutpaintSlice(...slice),
...createUpscaleSlice(...slice), ...createUpscaleSlice(...slice),
...createBlendSlice(...slice),
...createResetSlice(...slice), ...createResetSlice(...slice),
}), { }), {
name: 'onnx-web', name: 'onnx-web',

View File

@ -1,3 +1,4 @@
/* eslint-disable max-lines */
/* eslint-disable no-null/no-null */ /* eslint-disable no-null/no-null */
import { doesExist, Maybe } from '@apextoaster/js-utils'; import { doesExist, Maybe } from '@apextoaster/js-utils';
import { createContext } from 'react'; import { createContext } from 'react';
@ -6,6 +7,7 @@ import { StateCreator, StoreApi } from 'zustand';
import { import {
ApiClient, ApiClient,
BaseImgParams, BaseImgParams,
BlendParams,
BrushParams, BrushParams,
ImageResponse, ImageResponse,
Img2ImgParams, Img2ImgParams,
@ -100,6 +102,13 @@ interface UpscaleSlice {
resetUpscaleTab(): void; resetUpscaleTab(): void;
} }
interface BlendSlice {
blend: TabState<BlendParams>;
setBlend(blend: Partial<BlendParams>): void;
resetBlend(): void;
}
interface ResetSlice { interface ResetSlice {
resetAll(): void; resetAll(): void;
} }
@ -118,6 +127,7 @@ export type OnnxState
& OutpaintSlice & OutpaintSlice
& Txt2ImgSlice & Txt2ImgSlice
& UpscaleSlice & UpscaleSlice
& BlendSlice
& ResetSlice; & ResetSlice;
/** /**
@ -419,6 +429,29 @@ export function createStateSlices(server: ServerParams) {
}, },
}); });
const createBlendSlice: Slice<BlendSlice> = (set) => ({
blend: {
mask: null,
sources: [],
},
setBlend(blend) {
set((prev) => ({
blend: {
...prev.blend,
...blend,
},
}));
},
resetBlend() {
set((prev) => ({
blend: {
mask: null,
sources: [],
},
}));
},
});
const createDefaultSlice: Slice<DefaultSlice> = (set) => ({ const createDefaultSlice: Slice<DefaultSlice> = (set) => ({
defaults: { defaults: {
...base, ...base,
@ -459,6 +492,7 @@ export function createStateSlices(server: ServerParams) {
next.resetInpaint(); next.resetInpaint();
next.resetTxt2Img(); next.resetTxt2Img();
next.resetUpscaleTab(); next.resetUpscaleTab();
next.resetBlend();
// TODO: reset more stuff // TODO: reset more stuff
return next; return next;
}); });
@ -475,6 +509,7 @@ export function createStateSlices(server: ServerParams) {
createOutpaintSlice, createOutpaintSlice,
createTxt2ImgSlice, createTxt2ImgSlice,
createUpscaleSlice, createUpscaleSlice,
createBlendSlice,
createResetSlice, createResetSlice,
}; };
} }