1
0
Fork 0

feat(gui): add outpainting to API client and state

This commit is contained in:
Sean Sube 2023-01-14 18:07:05 -06:00
parent 34fa3f6341
commit 6cd98bb960
8 changed files with 148 additions and 55 deletions

View File

@ -43,11 +43,21 @@ export type Txt2ImgResponse = Required<Txt2ImgParams>;
export interface InpaintParams extends BaseImgParams {
mask: Blob;
source: Blob;
}
left?: number;
right?: number;
top?: number;
bottom?: number;
export interface OutpaintPixels {
left: number;
right: number;
top: number;
bottom: number;
}
export type OutpaintParams = InpaintParams & OutpaintPixels;
export interface BrushParams {
color: number;
size: number;
strength: number;
}
export interface ApiResponse {
@ -71,6 +81,7 @@ export interface ApiClient {
img2img(params: Img2ImgParams): Promise<ApiResponse>;
txt2img(params: Txt2ImgParams): Promise<ApiResponse>;
inpaint(params: InpaintParams): Promise<ApiResponse>;
outpaint(params: OutpaintParams): Promise<ApiResponse>;
ready(params: ApiResponse): Promise<ApiReady>;
}
@ -211,11 +222,29 @@ export function makeClient(root: string, f = fetch): ApiClient {
const url = makeImageURL(root, 'inpaint', params);
const body = new FormData();
body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source');
pending = throttleRequest(url, {
body,
method: 'POST',
});
// eslint-disable-next-line no-return-await
return await pending;
},
async outpaint(params: OutpaintParams) {
if (doesExist(pending)) {
return pending;
}
const url = makeImageURL(root, 'inpaint', params);
if (doesExist(params.left)) {
url.searchParams.append('left', params.left.toFixed(0));
}
if (doesExist(params.right)) {
url.searchParams.append('right', params.right.toFixed(0));
}

View File

@ -1,4 +1,4 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation, useQueryClient } from 'react-query';
@ -25,15 +25,30 @@ export function Inpaint(props: InpaintProps) {
const client = mustExist(useContext(ClientContext));
async function uploadSource(): Promise<void> {
const output = await client.inpaint({
...params,
model,
platform,
mask: mustExist(params.mask),
source: mustExist(params.source),
});
const outpaint = state.getState().outpaint; // TODO: seems shady
setLoading(output);
if (outpaint.bottom > 0 || outpaint.left > 0 || outpaint.right > 0 || outpaint.top > 0) {
const output = await client.outpaint({
...params,
...outpaint,
model,
platform,
mask: mustExist(params.mask),
source: mustExist(params.source),
});
setLoading(output);
} else {
const output = await client.inpaint({
...params,
model,
platform,
mask: mustExist(params.mask),
source: mustExist(params.source),
});
setLoading(output);
}
}
const state = mustExist(useContext(StateContext));

View File

@ -2,9 +2,11 @@ import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient } from '@mui/icons-material';
import { Button, Stack } from '@mui/material';
import { throttle } from 'lodash';
import React, { RefObject, useEffect, useMemo, useRef, useState } from 'react';
import React, { RefObject, useContext, useEffect, useMemo, useRef, useState } from 'react';
import { useStore } from 'zustand';
import { ConfigParams, DEFAULT_BRUSH, SAVE_TIME } from '../config.js';
import { ConfigParams, SAVE_TIME } from '../config.js';
import { StateContext } from '../state.js';
import { NumericField } from './NumericField';
export const FULL_CIRCLE = 2 * Math.PI;
@ -111,18 +113,20 @@ export function MaskCanvas(props: MaskCanvasProps) {
const maskState = useRef(MASK_STATE.clean);
const [background, setBackground] = useState<string>();
const [clicks, setClicks] = useState<Array<Point>>([]);
const [brushColor, setBrushColor] = useState(DEFAULT_BRUSH.color);
const [brushOpacity, setBrushOpacity] = useState(1.0);
const [brushSize, setBrushSize] = useState(DEFAULT_BRUSH.size);
const state = mustExist(useContext(StateContext));
const brush = useStore(state, (s) => s.brush);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBrush = useStore(state, (s) => s.setBrush);
useEffect(() => {
// including clicks.length prevents the initial render from saving a blank canvas
if (doesExist(bufferRef.current) && maskState.current === MASK_STATE.painting && clicks.length > 0) {
const { ctx } = getContext(bufferRef);
ctx.fillStyle = grayToRGB(brushColor, brushOpacity);
ctx.fillStyle = grayToRGB(brush.color, brush.strength);
for (const click of clicks) {
drawCircle(ctx, click, brushSize);
drawCircle(ctx, click, brush.size);
}
clicks.length = 0;
@ -188,12 +192,12 @@ export function MaskCanvas(props: MaskCanvasProps) {
const bounds = canvas.getBoundingClientRect();
const { ctx } = getContext(bufferRef);
ctx.fillStyle = grayToRGB(brushColor, brushOpacity);
ctx.fillStyle = grayToRGB(brush.color, brush.strength);
drawCircle(ctx, {
x: event.clientX - bounds.left,
y: event.clientY - bounds.top,
}, brushSize);
}, brush.size);
maskState.current = MASK_STATE.dirty;
save();
@ -215,12 +219,12 @@ export function MaskCanvas(props: MaskCanvasProps) {
}]);
} else {
const { ctx } = getClearContext(brushRef);
ctx.fillStyle = grayToRGB(brushColor, brushOpacity);
ctx.fillStyle = grayToRGB(brush.color, brush.strength);
drawCircle(ctx, {
x: event.clientX - bounds.left,
y: event.clientY - bounds.top,
}, brushSize);
}, brush.size);
drawBuffer();
}
@ -228,13 +232,13 @@ export function MaskCanvas(props: MaskCanvasProps) {
/>
<Stack direction='row' spacing={4}>
<NumericField
label='Brush Shade'
label='Brush Color'
min={0}
max={255}
step={1}
value={brushColor}
onChange={(value) => {
setBrushColor(value);
value={brush.color}
onChange={(color) => {
setBrush({ color });
}}
/>
<NumericField
@ -242,9 +246,9 @@ export function MaskCanvas(props: MaskCanvasProps) {
min={4}
max={64}
step={1}
value={brushSize}
onChange={(value) => {
setBrushSize(value);
value={brush.size}
onChange={(size) => {
setBrush({ size });
}}
/>
<NumericField
@ -253,9 +257,9 @@ export function MaskCanvas(props: MaskCanvasProps) {
min={0}
max={1}
step={0.01}
value={brushOpacity}
onChange={(value) => {
setBrushOpacity(value);
value={brush.strength}
onChange={(strength) => {
setBrush({ strength });
}}
/>
<Button

View File

@ -1,4 +1,4 @@
import { Alert, AlertTitle, Box, Container, Stack, Typography } from '@mui/material';
import { Alert, AlertTitle, Box, Container, Link, Stack, Typography } from '@mui/material';
import * as React from 'react';
export interface OnnxErrorProps {
@ -12,7 +12,7 @@ export function OnnxError(props: OnnxErrorProps) {
<Container>
<Box sx={{ my: 4 }}>
<Typography variant='h3' gutterBottom>
<a href='https://github.com/ssube/onnx-web'>ONNX Web</a>
<Link href='https://github.com/ssube/onnx-web' target='_blank' underline='hover'>ONNX Web</Link>
</Typography>
</Box>
<Box sx={{ my: 4 }}>

View File

@ -1,6 +1,6 @@
import { mustExist } from '@apextoaster/js-utils';
import { TabContext, TabList, TabPanel } from '@mui/lab';
import { Box, Container, Divider, Stack, Tab, Typography } from '@mui/material';
import { Box, Container, Divider, Link, Stack, Tab, Typography } from '@mui/material';
import * as React from 'react';
import { useQuery } from 'react-query';
@ -41,7 +41,7 @@ export function OnnxWeb(props: OnnxWebProps) {
<Container>
<Box sx={{ my: 4 }}>
<Typography variant='h3' gutterBottom>
<a href='https://github.com/ssube/onnx-web'>ONNX Web</a>
<Link href='https://github.com/ssube/onnx-web' target='_blank' underline='hover'>ONNX Web</Link>
</Typography>
</Box>
<Box sx={{ mx: 4, my: 4 }}>

View File

@ -16,9 +16,9 @@ export function OutpaintControl(props: OutpaintControlProps) {
const { config } = props;
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.inpaint);
const params = useStore(state, (s) => s.outpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint);
const setOutpaint = useStore(state, (s) => s.setOutpaint);
return <Stack direction='row' spacing={4}>
<NumericField
@ -28,7 +28,7 @@ export function OutpaintControl(props: OutpaintControlProps) {
step={config.width.step}
value={params.left}
onChange={(left) => {
setInpaint({
setOutpaint({
left,
});
}}
@ -40,7 +40,7 @@ export function OutpaintControl(props: OutpaintControlProps) {
step={config.width.step}
value={params.right}
onChange={(right) => {
setInpaint({
setOutpaint({
right,
});
}}
@ -52,7 +52,7 @@ export function OutpaintControl(props: OutpaintControlProps) {
step={config.height.step}
value={params.top}
onChange={(top) => {
setInpaint({
setOutpaint({
top,
});
}}
@ -64,7 +64,7 @@ export function OutpaintControl(props: OutpaintControlProps) {
step={config.height.step}
value={params.bottom}
onChange={(bottom) => {
setInpaint({
setOutpaint({
bottom,
});
}}

View File

@ -3,7 +3,7 @@ import { doesExist, mustExist } from '@apextoaster/js-utils';
import { merge } from 'lodash';
import * as React from 'react';
import ReactDOM from 'react-dom/client';
import { QueryCache, QueryClient, QueryClientProvider } from 'react-query';
import { QueryClient, QueryClientProvider } from 'react-query';
import { createStore } from 'zustand';
import { createJSONStorage, persist } from 'zustand/middleware';
@ -48,6 +48,8 @@ export async function main() {
createImg2ImgSlice,
createInpaintSlice,
createTxt2ImgSlice,
createBrushSlice,
createOutpaintSlice,
} = createStateSlices(params);
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((...slice) => ({
...createTxt2ImgSlice(...slice),
@ -55,6 +57,8 @@ export async function main() {
...createInpaintSlice(...slice),
...createHistorySlice(...slice),
...createDefaultSlice(...slice),
...createBrushSlice(...slice),
...createOutpaintSlice(...slice),
}), {
name: 'onnx-web',
partialize(s) {

View File

@ -7,8 +7,10 @@ import {
ApiClient,
ApiResponse,
BaseImgParams,
BrushParams,
Img2ImgParams,
InpaintParams,
OutpaintPixels,
paramsFromConfig,
Txt2ImgParams,
} from './api/client.js';
@ -54,7 +56,19 @@ interface DefaultSlice {
setDefaults(param: Partial<BaseImgParams>): void;
}
export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice;
interface OutpaintSlice {
outpaint: OutpaintPixels;
setOutpaint(pixels: Partial<OutpaintPixels>): void;
}
interface BrushSlice {
brush: BrushParams;
setBrush(brush: Partial<BrushParams>): void;
}
export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice & OutpaintSlice & BrushSlice;
export function createStateSlices(base: ConfigParams) {
const defaults = paramsFromConfig(base);
@ -114,10 +128,6 @@ export function createStateSlices(base: ConfigParams) {
...defaults,
mask: null,
source: null,
left: 0,
right: 0,
top: 0,
bottom: 0,
},
setInpaint(params) {
set((prev) => ({
@ -133,10 +143,6 @@ export function createStateSlices(base: ConfigParams) {
...defaults,
mask: null,
source: null,
left: 0,
right: 0,
top: 0,
bottom: 0,
},
});
},
@ -176,6 +182,39 @@ export function createStateSlices(base: ConfigParams) {
},
});
const createOutpaintSlice: StateCreator<OnnxState, [], [], OutpaintSlice> = (set) => ({
outpaint: {
left: 0,
right: 0,
top: 0,
bottom: 0,
},
setOutpaint(pixels) {
set((prev) => ({
outpaint: {
...prev.outpaint,
...pixels,
}
}));
},
});
const createBrushSlice: StateCreator<OnnxState, [], [], BrushSlice> = (set) => ({
brush: {
color: 255,
size: 8,
strength: 0.5,
},
setBrush(brush) {
set((prev) => ({
brush: {
...prev.brush,
...brush,
},
}));
},
});
const createDefaultSlice: StateCreator<OnnxState, [], [], DefaultSlice> = (set) => ({
defaults: {
...defaults,
@ -196,6 +235,8 @@ export function createStateSlices(base: ConfigParams) {
createImg2ImgSlice,
createInpaintSlice,
createTxt2ImgSlice,
createOutpaintSlice,
createBrushSlice,
};
}