From 6cd98bb96002f3177c08eeec4fa174b1cb7156ae Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 14 Jan 2023 18:07:05 -0600 Subject: [PATCH] feat(gui): add outpainting to API client and state --- gui/src/api/client.ts | 39 ++++++++++++++--- gui/src/components/Inpaint.tsx | 33 ++++++++++---- gui/src/components/MaskCanvas.tsx | 46 +++++++++++--------- gui/src/components/OnnxError.tsx | 4 +- gui/src/components/OnnxWeb.tsx | 4 +- gui/src/components/OutpaintControl.tsx | 12 +++--- gui/src/main.tsx | 6 ++- gui/src/state.ts | 59 ++++++++++++++++++++++---- 8 files changed, 148 insertions(+), 55 deletions(-) diff --git a/gui/src/api/client.ts b/gui/src/api/client.ts index b70d96ca..f7513ed7 100644 --- a/gui/src/api/client.ts +++ b/gui/src/api/client.ts @@ -43,11 +43,21 @@ export type Txt2ImgResponse = Required; export interface InpaintParams extends BaseImgParams { mask: Blob; source: Blob; +} - left?: number; - right?: number; - top?: number; - bottom?: number; +export interface OutpaintPixels { + left: number; + right: number; + top: number; + bottom: number; +} + +export type OutpaintParams = InpaintParams & OutpaintPixels; + +export interface BrushParams { + color: number; + size: number; + strength: number; } export interface ApiResponse { @@ -71,6 +81,7 @@ export interface ApiClient { img2img(params: Img2ImgParams): Promise; txt2img(params: Txt2ImgParams): Promise; inpaint(params: InpaintParams): Promise; + outpaint(params: OutpaintParams): Promise; ready(params: ApiResponse): Promise; } @@ -211,11 +222,29 @@ export function makeClient(root: string, f = fetch): ApiClient { const url = makeImageURL(root, 'inpaint', params); + const body = new FormData(); + body.append('mask', params.mask, 'mask'); + body.append('source', params.source, 'source'); + + pending = throttleRequest(url, { + body, + method: 'POST', + }); + + // eslint-disable-next-line no-return-await + return await pending; + }, + async outpaint(params: OutpaintParams) { + if (doesExist(pending)) { + return pending; + } + + const url = makeImageURL(root, 'inpaint', params); + if (doesExist(params.left)) { url.searchParams.append('left', params.left.toFixed(0)); } - if (doesExist(params.right)) { url.searchParams.append('right', params.right.toFixed(0)); } diff --git a/gui/src/components/Inpaint.tsx b/gui/src/components/Inpaint.tsx index f441937d..f4f85f65 100644 --- a/gui/src/components/Inpaint.tsx +++ b/gui/src/components/Inpaint.tsx @@ -1,4 +1,4 @@ -import { doesExist, mustExist } from '@apextoaster/js-utils'; +import { mustExist } from '@apextoaster/js-utils'; import { Box, Button, Stack } from '@mui/material'; import * as React from 'react'; import { useMutation, useQueryClient } from 'react-query'; @@ -25,15 +25,30 @@ export function Inpaint(props: InpaintProps) { const client = mustExist(useContext(ClientContext)); async function uploadSource(): Promise { - const output = await client.inpaint({ - ...params, - model, - platform, - mask: mustExist(params.mask), - source: mustExist(params.source), - }); + const outpaint = state.getState().outpaint; // TODO: seems shady - setLoading(output); + if (outpaint.bottom > 0 || outpaint.left > 0 || outpaint.right > 0 || outpaint.top > 0) { + const output = await client.outpaint({ + ...params, + ...outpaint, + model, + platform, + mask: mustExist(params.mask), + source: mustExist(params.source), + }); + + setLoading(output); + } else { + const output = await client.inpaint({ + ...params, + model, + platform, + mask: mustExist(params.mask), + source: mustExist(params.source), + }); + + setLoading(output); + } } const state = mustExist(useContext(StateContext)); diff --git a/gui/src/components/MaskCanvas.tsx b/gui/src/components/MaskCanvas.tsx index d2765e71..6e0f95f6 100644 --- a/gui/src/components/MaskCanvas.tsx +++ b/gui/src/components/MaskCanvas.tsx @@ -2,9 +2,11 @@ import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils'; import { FormatColorFill, Gradient } from '@mui/icons-material'; import { Button, Stack } from '@mui/material'; import { throttle } from 'lodash'; -import React, { RefObject, useEffect, useMemo, useRef, useState } from 'react'; +import React, { RefObject, useContext, useEffect, useMemo, useRef, useState } from 'react'; +import { useStore } from 'zustand'; -import { ConfigParams, DEFAULT_BRUSH, SAVE_TIME } from '../config.js'; +import { ConfigParams, SAVE_TIME } from '../config.js'; +import { StateContext } from '../state.js'; import { NumericField } from './NumericField'; export const FULL_CIRCLE = 2 * Math.PI; @@ -111,18 +113,20 @@ export function MaskCanvas(props: MaskCanvasProps) { const maskState = useRef(MASK_STATE.clean); const [background, setBackground] = useState(); const [clicks, setClicks] = useState>([]); - const [brushColor, setBrushColor] = useState(DEFAULT_BRUSH.color); - const [brushOpacity, setBrushOpacity] = useState(1.0); - const [brushSize, setBrushSize] = useState(DEFAULT_BRUSH.size); + + const state = mustExist(useContext(StateContext)); + const brush = useStore(state, (s) => s.brush); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setBrush = useStore(state, (s) => s.setBrush); useEffect(() => { // including clicks.length prevents the initial render from saving a blank canvas if (doesExist(bufferRef.current) && maskState.current === MASK_STATE.painting && clicks.length > 0) { const { ctx } = getContext(bufferRef); - ctx.fillStyle = grayToRGB(brushColor, brushOpacity); + ctx.fillStyle = grayToRGB(brush.color, brush.strength); for (const click of clicks) { - drawCircle(ctx, click, brushSize); + drawCircle(ctx, click, brush.size); } clicks.length = 0; @@ -188,12 +192,12 @@ export function MaskCanvas(props: MaskCanvasProps) { const bounds = canvas.getBoundingClientRect(); const { ctx } = getContext(bufferRef); - ctx.fillStyle = grayToRGB(brushColor, brushOpacity); + ctx.fillStyle = grayToRGB(brush.color, brush.strength); drawCircle(ctx, { x: event.clientX - bounds.left, y: event.clientY - bounds.top, - }, brushSize); + }, brush.size); maskState.current = MASK_STATE.dirty; save(); @@ -215,12 +219,12 @@ export function MaskCanvas(props: MaskCanvasProps) { }]); } else { const { ctx } = getClearContext(brushRef); - ctx.fillStyle = grayToRGB(brushColor, brushOpacity); + ctx.fillStyle = grayToRGB(brush.color, brush.strength); drawCircle(ctx, { x: event.clientX - bounds.left, y: event.clientY - bounds.top, - }, brushSize); + }, brush.size); drawBuffer(); } @@ -228,13 +232,13 @@ export function MaskCanvas(props: MaskCanvasProps) { /> { - setBrushColor(value); + value={brush.color} + onChange={(color) => { + setBrush({ color }); }} /> { - setBrushSize(value); + value={brush.size} + onChange={(size) => { + setBrush({ size }); }} /> { - setBrushOpacity(value); + value={brush.strength} + onChange={(strength) => { + setBrush({ strength }); }} />