feat(gui): add outpainting to API client and state
This commit is contained in:
parent
34fa3f6341
commit
6cd98bb960
|
@ -43,11 +43,21 @@ export type Txt2ImgResponse = Required<Txt2ImgParams>;
|
|||
export interface InpaintParams extends BaseImgParams {
|
||||
mask: Blob;
|
||||
source: Blob;
|
||||
}
|
||||
|
||||
left?: number;
|
||||
right?: number;
|
||||
top?: number;
|
||||
bottom?: number;
|
||||
export interface OutpaintPixels {
|
||||
left: number;
|
||||
right: number;
|
||||
top: number;
|
||||
bottom: number;
|
||||
}
|
||||
|
||||
export type OutpaintParams = InpaintParams & OutpaintPixels;
|
||||
|
||||
export interface BrushParams {
|
||||
color: number;
|
||||
size: number;
|
||||
strength: number;
|
||||
}
|
||||
|
||||
export interface ApiResponse {
|
||||
|
@ -71,6 +81,7 @@ export interface ApiClient {
|
|||
img2img(params: Img2ImgParams): Promise<ApiResponse>;
|
||||
txt2img(params: Txt2ImgParams): Promise<ApiResponse>;
|
||||
inpaint(params: InpaintParams): Promise<ApiResponse>;
|
||||
outpaint(params: OutpaintParams): Promise<ApiResponse>;
|
||||
|
||||
ready(params: ApiResponse): Promise<ApiReady>;
|
||||
}
|
||||
|
@ -211,11 +222,29 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
|
||||
const url = makeImageURL(root, 'inpaint', params);
|
||||
|
||||
const body = new FormData();
|
||||
body.append('mask', params.mask, 'mask');
|
||||
body.append('source', params.source, 'source');
|
||||
|
||||
pending = throttleRequest(url, {
|
||||
body,
|
||||
method: 'POST',
|
||||
});
|
||||
|
||||
// eslint-disable-next-line no-return-await
|
||||
return await pending;
|
||||
},
|
||||
async outpaint(params: OutpaintParams) {
|
||||
if (doesExist(pending)) {
|
||||
return pending;
|
||||
}
|
||||
|
||||
const url = makeImageURL(root, 'inpaint', params);
|
||||
|
||||
if (doesExist(params.left)) {
|
||||
url.searchParams.append('left', params.left.toFixed(0));
|
||||
}
|
||||
|
||||
|
||||
if (doesExist(params.right)) {
|
||||
url.searchParams.append('right', params.right.toFixed(0));
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Box, Button, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useMutation, useQueryClient } from 'react-query';
|
||||
|
@ -25,15 +25,30 @@ export function Inpaint(props: InpaintProps) {
|
|||
const client = mustExist(useContext(ClientContext));
|
||||
|
||||
async function uploadSource(): Promise<void> {
|
||||
const output = await client.inpaint({
|
||||
...params,
|
||||
model,
|
||||
platform,
|
||||
mask: mustExist(params.mask),
|
||||
source: mustExist(params.source),
|
||||
});
|
||||
const outpaint = state.getState().outpaint; // TODO: seems shady
|
||||
|
||||
setLoading(output);
|
||||
if (outpaint.bottom > 0 || outpaint.left > 0 || outpaint.right > 0 || outpaint.top > 0) {
|
||||
const output = await client.outpaint({
|
||||
...params,
|
||||
...outpaint,
|
||||
model,
|
||||
platform,
|
||||
mask: mustExist(params.mask),
|
||||
source: mustExist(params.source),
|
||||
});
|
||||
|
||||
setLoading(output);
|
||||
} else {
|
||||
const output = await client.inpaint({
|
||||
...params,
|
||||
model,
|
||||
platform,
|
||||
mask: mustExist(params.mask),
|
||||
source: mustExist(params.source),
|
||||
});
|
||||
|
||||
setLoading(output);
|
||||
}
|
||||
}
|
||||
|
||||
const state = mustExist(useContext(StateContext));
|
||||
|
|
|
@ -2,9 +2,11 @@ import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils';
|
|||
import { FormatColorFill, Gradient } from '@mui/icons-material';
|
||||
import { Button, Stack } from '@mui/material';
|
||||
import { throttle } from 'lodash';
|
||||
import React, { RefObject, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import React, { RefObject, useContext, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigParams, DEFAULT_BRUSH, SAVE_TIME } from '../config.js';
|
||||
import { ConfigParams, SAVE_TIME } from '../config.js';
|
||||
import { StateContext } from '../state.js';
|
||||
import { NumericField } from './NumericField';
|
||||
|
||||
export const FULL_CIRCLE = 2 * Math.PI;
|
||||
|
@ -111,18 +113,20 @@ export function MaskCanvas(props: MaskCanvasProps) {
|
|||
const maskState = useRef(MASK_STATE.clean);
|
||||
const [background, setBackground] = useState<string>();
|
||||
const [clicks, setClicks] = useState<Array<Point>>([]);
|
||||
const [brushColor, setBrushColor] = useState(DEFAULT_BRUSH.color);
|
||||
const [brushOpacity, setBrushOpacity] = useState(1.0);
|
||||
const [brushSize, setBrushSize] = useState(DEFAULT_BRUSH.size);
|
||||
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const brush = useStore(state, (s) => s.brush);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setBrush = useStore(state, (s) => s.setBrush);
|
||||
|
||||
useEffect(() => {
|
||||
// including clicks.length prevents the initial render from saving a blank canvas
|
||||
if (doesExist(bufferRef.current) && maskState.current === MASK_STATE.painting && clicks.length > 0) {
|
||||
const { ctx } = getContext(bufferRef);
|
||||
ctx.fillStyle = grayToRGB(brushColor, brushOpacity);
|
||||
ctx.fillStyle = grayToRGB(brush.color, brush.strength);
|
||||
|
||||
for (const click of clicks) {
|
||||
drawCircle(ctx, click, brushSize);
|
||||
drawCircle(ctx, click, brush.size);
|
||||
}
|
||||
|
||||
clicks.length = 0;
|
||||
|
@ -188,12 +192,12 @@ export function MaskCanvas(props: MaskCanvasProps) {
|
|||
const bounds = canvas.getBoundingClientRect();
|
||||
|
||||
const { ctx } = getContext(bufferRef);
|
||||
ctx.fillStyle = grayToRGB(brushColor, brushOpacity);
|
||||
ctx.fillStyle = grayToRGB(brush.color, brush.strength);
|
||||
|
||||
drawCircle(ctx, {
|
||||
x: event.clientX - bounds.left,
|
||||
y: event.clientY - bounds.top,
|
||||
}, brushSize);
|
||||
}, brush.size);
|
||||
|
||||
maskState.current = MASK_STATE.dirty;
|
||||
save();
|
||||
|
@ -215,12 +219,12 @@ export function MaskCanvas(props: MaskCanvasProps) {
|
|||
}]);
|
||||
} else {
|
||||
const { ctx } = getClearContext(brushRef);
|
||||
ctx.fillStyle = grayToRGB(brushColor, brushOpacity);
|
||||
ctx.fillStyle = grayToRGB(brush.color, brush.strength);
|
||||
|
||||
drawCircle(ctx, {
|
||||
x: event.clientX - bounds.left,
|
||||
y: event.clientY - bounds.top,
|
||||
}, brushSize);
|
||||
}, brush.size);
|
||||
|
||||
drawBuffer();
|
||||
}
|
||||
|
@ -228,13 +232,13 @@ export function MaskCanvas(props: MaskCanvasProps) {
|
|||
/>
|
||||
<Stack direction='row' spacing={4}>
|
||||
<NumericField
|
||||
label='Brush Shade'
|
||||
label='Brush Color'
|
||||
min={0}
|
||||
max={255}
|
||||
step={1}
|
||||
value={brushColor}
|
||||
onChange={(value) => {
|
||||
setBrushColor(value);
|
||||
value={brush.color}
|
||||
onChange={(color) => {
|
||||
setBrush({ color });
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
|
@ -242,9 +246,9 @@ export function MaskCanvas(props: MaskCanvasProps) {
|
|||
min={4}
|
||||
max={64}
|
||||
step={1}
|
||||
value={brushSize}
|
||||
onChange={(value) => {
|
||||
setBrushSize(value);
|
||||
value={brush.size}
|
||||
onChange={(size) => {
|
||||
setBrush({ size });
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
|
@ -253,9 +257,9 @@ export function MaskCanvas(props: MaskCanvasProps) {
|
|||
min={0}
|
||||
max={1}
|
||||
step={0.01}
|
||||
value={brushOpacity}
|
||||
onChange={(value) => {
|
||||
setBrushOpacity(value);
|
||||
value={brush.strength}
|
||||
onChange={(strength) => {
|
||||
setBrush({ strength });
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { Alert, AlertTitle, Box, Container, Stack, Typography } from '@mui/material';
|
||||
import { Alert, AlertTitle, Box, Container, Link, Stack, Typography } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
|
||||
export interface OnnxErrorProps {
|
||||
|
@ -12,7 +12,7 @@ export function OnnxError(props: OnnxErrorProps) {
|
|||
<Container>
|
||||
<Box sx={{ my: 4 }}>
|
||||
<Typography variant='h3' gutterBottom>
|
||||
<a href='https://github.com/ssube/onnx-web'>ONNX Web</a>
|
||||
<Link href='https://github.com/ssube/onnx-web' target='_blank' underline='hover'>ONNX Web</Link>
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box sx={{ my: 4 }}>
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { TabContext, TabList, TabPanel } from '@mui/lab';
|
||||
import { Box, Container, Divider, Stack, Tab, Typography } from '@mui/material';
|
||||
import { Box, Container, Divider, Link, Stack, Tab, Typography } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useQuery } from 'react-query';
|
||||
|
||||
|
@ -41,7 +41,7 @@ export function OnnxWeb(props: OnnxWebProps) {
|
|||
<Container>
|
||||
<Box sx={{ my: 4 }}>
|
||||
<Typography variant='h3' gutterBottom>
|
||||
<a href='https://github.com/ssube/onnx-web'>ONNX Web</a>
|
||||
<Link href='https://github.com/ssube/onnx-web' target='_blank' underline='hover'>ONNX Web</Link>
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box sx={{ mx: 4, my: 4 }}>
|
||||
|
|
|
@ -16,9 +16,9 @@ export function OutpaintControl(props: OutpaintControlProps) {
|
|||
const { config } = props;
|
||||
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const params = useStore(state, (s) => s.inpaint);
|
||||
const params = useStore(state, (s) => s.outpaint);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setInpaint = useStore(state, (s) => s.setInpaint);
|
||||
const setOutpaint = useStore(state, (s) => s.setOutpaint);
|
||||
|
||||
return <Stack direction='row' spacing={4}>
|
||||
<NumericField
|
||||
|
@ -28,7 +28,7 @@ export function OutpaintControl(props: OutpaintControlProps) {
|
|||
step={config.width.step}
|
||||
value={params.left}
|
||||
onChange={(left) => {
|
||||
setInpaint({
|
||||
setOutpaint({
|
||||
left,
|
||||
});
|
||||
}}
|
||||
|
@ -40,7 +40,7 @@ export function OutpaintControl(props: OutpaintControlProps) {
|
|||
step={config.width.step}
|
||||
value={params.right}
|
||||
onChange={(right) => {
|
||||
setInpaint({
|
||||
setOutpaint({
|
||||
right,
|
||||
});
|
||||
}}
|
||||
|
@ -52,7 +52,7 @@ export function OutpaintControl(props: OutpaintControlProps) {
|
|||
step={config.height.step}
|
||||
value={params.top}
|
||||
onChange={(top) => {
|
||||
setInpaint({
|
||||
setOutpaint({
|
||||
top,
|
||||
});
|
||||
}}
|
||||
|
@ -64,7 +64,7 @@ export function OutpaintControl(props: OutpaintControlProps) {
|
|||
step={config.height.step}
|
||||
value={params.bottom}
|
||||
onChange={(bottom) => {
|
||||
setInpaint({
|
||||
setOutpaint({
|
||||
bottom,
|
||||
});
|
||||
}}
|
||||
|
|
|
@ -3,7 +3,7 @@ import { doesExist, mustExist } from '@apextoaster/js-utils';
|
|||
import { merge } from 'lodash';
|
||||
import * as React from 'react';
|
||||
import ReactDOM from 'react-dom/client';
|
||||
import { QueryCache, QueryClient, QueryClientProvider } from 'react-query';
|
||||
import { QueryClient, QueryClientProvider } from 'react-query';
|
||||
import { createStore } from 'zustand';
|
||||
import { createJSONStorage, persist } from 'zustand/middleware';
|
||||
|
||||
|
@ -48,6 +48,8 @@ export async function main() {
|
|||
createImg2ImgSlice,
|
||||
createInpaintSlice,
|
||||
createTxt2ImgSlice,
|
||||
createBrushSlice,
|
||||
createOutpaintSlice,
|
||||
} = createStateSlices(params);
|
||||
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((...slice) => ({
|
||||
...createTxt2ImgSlice(...slice),
|
||||
|
@ -55,6 +57,8 @@ export async function main() {
|
|||
...createInpaintSlice(...slice),
|
||||
...createHistorySlice(...slice),
|
||||
...createDefaultSlice(...slice),
|
||||
...createBrushSlice(...slice),
|
||||
...createOutpaintSlice(...slice),
|
||||
}), {
|
||||
name: 'onnx-web',
|
||||
partialize(s) {
|
||||
|
|
|
@ -7,8 +7,10 @@ import {
|
|||
ApiClient,
|
||||
ApiResponse,
|
||||
BaseImgParams,
|
||||
BrushParams,
|
||||
Img2ImgParams,
|
||||
InpaintParams,
|
||||
OutpaintPixels,
|
||||
paramsFromConfig,
|
||||
Txt2ImgParams,
|
||||
} from './api/client.js';
|
||||
|
@ -54,7 +56,19 @@ interface DefaultSlice {
|
|||
setDefaults(param: Partial<BaseImgParams>): void;
|
||||
}
|
||||
|
||||
export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice;
|
||||
interface OutpaintSlice {
|
||||
outpaint: OutpaintPixels;
|
||||
|
||||
setOutpaint(pixels: Partial<OutpaintPixels>): void;
|
||||
}
|
||||
|
||||
interface BrushSlice {
|
||||
brush: BrushParams;
|
||||
|
||||
setBrush(brush: Partial<BrushParams>): void;
|
||||
}
|
||||
|
||||
export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice & OutpaintSlice & BrushSlice;
|
||||
|
||||
export function createStateSlices(base: ConfigParams) {
|
||||
const defaults = paramsFromConfig(base);
|
||||
|
@ -114,10 +128,6 @@ export function createStateSlices(base: ConfigParams) {
|
|||
...defaults,
|
||||
mask: null,
|
||||
source: null,
|
||||
left: 0,
|
||||
right: 0,
|
||||
top: 0,
|
||||
bottom: 0,
|
||||
},
|
||||
setInpaint(params) {
|
||||
set((prev) => ({
|
||||
|
@ -133,10 +143,6 @@ export function createStateSlices(base: ConfigParams) {
|
|||
...defaults,
|
||||
mask: null,
|
||||
source: null,
|
||||
left: 0,
|
||||
right: 0,
|
||||
top: 0,
|
||||
bottom: 0,
|
||||
},
|
||||
});
|
||||
},
|
||||
|
@ -176,6 +182,39 @@ export function createStateSlices(base: ConfigParams) {
|
|||
},
|
||||
});
|
||||
|
||||
const createOutpaintSlice: StateCreator<OnnxState, [], [], OutpaintSlice> = (set) => ({
|
||||
outpaint: {
|
||||
left: 0,
|
||||
right: 0,
|
||||
top: 0,
|
||||
bottom: 0,
|
||||
},
|
||||
setOutpaint(pixels) {
|
||||
set((prev) => ({
|
||||
outpaint: {
|
||||
...prev.outpaint,
|
||||
...pixels,
|
||||
}
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
const createBrushSlice: StateCreator<OnnxState, [], [], BrushSlice> = (set) => ({
|
||||
brush: {
|
||||
color: 255,
|
||||
size: 8,
|
||||
strength: 0.5,
|
||||
},
|
||||
setBrush(brush) {
|
||||
set((prev) => ({
|
||||
brush: {
|
||||
...prev.brush,
|
||||
...brush,
|
||||
},
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
const createDefaultSlice: StateCreator<OnnxState, [], [], DefaultSlice> = (set) => ({
|
||||
defaults: {
|
||||
...defaults,
|
||||
|
@ -196,6 +235,8 @@ export function createStateSlices(base: ConfigParams) {
|
|||
createImg2ImgSlice,
|
||||
createInpaintSlice,
|
||||
createTxt2ImgSlice,
|
||||
createOutpaintSlice,
|
||||
createBrushSlice,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue