From d1e4fa9cf1da5188b71d34d8a1753ba0e6186009 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 16 Jan 2023 14:05:04 -0600 Subject: [PATCH] feat: add upscale controls to client, params to server --- api/params.json | 12 +++++ gui/src/api/client.ts | 60 ++++++++++++++++----- gui/src/components/Img2Img.tsx | 2 + gui/src/components/Inpaint.tsx | 2 + gui/src/components/Txt2Img.tsx | 2 + gui/src/components/UpscaleControl.tsx | 78 +++++++++++++++++++++++++++ gui/src/config.ts | 4 +- gui/src/main.tsx | 14 ++--- gui/src/state.ts | 39 ++++++++++++-- 9 files changed, 188 insertions(+), 25 deletions(-) create mode 100644 gui/src/components/UpscaleControl.tsx diff --git a/api/params.json b/api/params.json index ab84b3f9..445fbab5 100644 --- a/api/params.json +++ b/api/params.json @@ -5,6 +5,12 @@ "max": 30, "step": 0.1 }, + "denoise": { + "default": 0.5, + "min": 0, + "max": 0, + "step": 0.1 + }, "height": { "default": 512, "min": 64, @@ -27,6 +33,12 @@ "default": "an astronaut eating a hamburger", "keys": [] }, + "scale": { + "default": 1, + "min": 1, + "max": 4, + "step": 1 + }, "scheduler": { "default": "euler-a", "keys": [] diff --git a/gui/src/api/client.ts b/gui/src/api/client.ts index 45912b94..e1fc196b 100644 --- a/gui/src/api/client.ts +++ b/gui/src/api/client.ts @@ -65,6 +65,14 @@ export interface BrushParams { strength: number; } +export interface UpscaleParams { + enabled: boolean; + + denoise: number; + faces: boolean; + scale: number; +} + export interface ApiResponse { output: { key: string; @@ -112,6 +120,9 @@ export function paramsFromConfig(defaults: ConfigParams): Required) { export function makeImageURL(root: string, type: string, params: BaseImgParams): URL { const url = makeApiUrl(root, type); - url.searchParams.append('cfg', params.cfg.toFixed(1)); - url.searchParams.append('steps', params.steps.toFixed(0)); + url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT)); + url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER)); if (doesExist(params.model)) { url.searchParams.append('model', params.model); @@ -142,7 +153,7 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams): } if (doesExist(params.seed)) { - url.searchParams.append('seed', params.seed.toFixed(0)); + url.searchParams.append('seed', params.seed.toFixed(FIXED_INTEGER)); } // put prompt last, in case a load balancer decides to truncate the URL @@ -155,6 +166,12 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams): return url; } +export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { + url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT)); + url.searchParams.append('faces', String(upscale.faces)); + url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER)); +} + export function makeClient(root: string, f = fetch): ApiClient { let pending: Promise | undefined; @@ -195,13 +212,17 @@ export function makeClient(root: string, f = fetch): ApiClient { const res = await f(path); return await res.json() as Array; }, - async img2img(params: Img2ImgParams): Promise { + async img2img(params: Img2ImgParams, upscale?: UpscaleParams): Promise { if (doesExist(pending)) { return pending; } const url = makeImageURL(root, 'img2img', params); - url.searchParams.append('strength', params.strength.toFixed(2)); + url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); + + if (doesExist(upscale)) { + appendUpscaleToURL(url, upscale); + } const body = new FormData(); body.append('source', params.source, 'source'); @@ -214,7 +235,7 @@ export function makeClient(root: string, f = fetch): ApiClient { // eslint-disable-next-line no-return-await return await pending; }, - async txt2img(params: Txt2ImgParams): Promise { + async txt2img(params: Txt2ImgParams, upscale?: UpscaleParams): Promise { if (doesExist(pending)) { return pending; } @@ -222,11 +243,15 @@ export function makeClient(root: string, f = fetch): ApiClient { const url = makeImageURL(root, 'txt2img', params); if (doesExist(params.width)) { - url.searchParams.append('width', params.width.toFixed(0)); + url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER)); } if (doesExist(params.height)) { - url.searchParams.append('height', params.height.toFixed(0)); + url.searchParams.append('height', params.height.toFixed(FIXED_INTEGER)); + } + + if (doesExist(upscale)) { + appendUpscaleToURL(url, upscale); } pending = throttleRequest(url, { @@ -236,7 +261,7 @@ export function makeClient(root: string, f = fetch): ApiClient { // eslint-disable-next-line no-return-await return await pending; }, - async inpaint(params: InpaintParams) { + async inpaint(params: InpaintParams, upscale?: UpscaleParams) { if (doesExist(pending)) { return pending; } @@ -244,6 +269,9 @@ export function makeClient(root: string, f = fetch): ApiClient { const url = makeImageURL(root, 'inpaint', params); url.searchParams.append('filter', params.filter); url.searchParams.append('noise', params.noise); + if (doesExist(upscale)) { + appendUpscaleToURL(url, upscale); + } const body = new FormData(); body.append('mask', params.mask, 'mask'); @@ -257,7 +285,7 @@ export function makeClient(root: string, f = fetch): ApiClient { // eslint-disable-next-line no-return-await return await pending; }, - async outpaint(params: OutpaintParams) { + async outpaint(params: OutpaintParams, upscale?: UpscaleParams) { if (doesExist(pending)) { return pending; } @@ -266,20 +294,24 @@ export function makeClient(root: string, f = fetch): ApiClient { url.searchParams.append('filter', params.filter); url.searchParams.append('noise', params.noise); + if (doesExist(upscale)) { + appendUpscaleToURL(url, upscale); + } + if (doesExist(params.left)) { - url.searchParams.append('left', params.left.toFixed(0)); + url.searchParams.append('left', params.left.toFixed(FIXED_INTEGER)); } if (doesExist(params.right)) { - url.searchParams.append('right', params.right.toFixed(0)); + url.searchParams.append('right', params.right.toFixed(FIXED_INTEGER)); } if (doesExist(params.top)) { - url.searchParams.append('top', params.top.toFixed(0)); + url.searchParams.append('top', params.top.toFixed(FIXED_INTEGER)); } if (doesExist(params.bottom)) { - url.searchParams.append('bottom', params.bottom.toFixed(0)); + url.searchParams.append('bottom', params.bottom.toFixed(FIXED_INTEGER)); } const body = new FormData(); diff --git a/gui/src/components/Img2Img.tsx b/gui/src/components/Img2Img.tsx index 5ed91c4e..ff0d872e 100644 --- a/gui/src/components/Img2Img.tsx +++ b/gui/src/components/Img2Img.tsx @@ -9,6 +9,7 @@ import { ClientContext, StateContext } from '../state.js'; import { ImageControl } from './ImageControl.js'; import { ImageInput } from './ImageInput.js'; import { NumericField } from './NumericField.js'; +import { UpscaleControl } from './UpscaleControl.js'; const { useContext } = React; @@ -69,6 +70,7 @@ export function Img2Img(props: Img2ImgProps) { }); }} /> + ; diff --git a/gui/src/components/Inpaint.tsx b/gui/src/components/Inpaint.tsx index 65f120ec..5f67b900 100644 --- a/gui/src/components/Inpaint.tsx +++ b/gui/src/components/Inpaint.tsx @@ -12,6 +12,7 @@ import { ImageInput } from './ImageInput.js'; import { MaskCanvas } from './MaskCanvas.js'; import { OutpaintControl } from './OutpaintControl.js'; import { QueryList } from './QueryList.js'; +import { UpscaleControl } from './UpscaleControl.js'; const { useContext } = React; @@ -139,6 +140,7 @@ export function Inpaint(props: InpaintProps) { /> + ; diff --git a/gui/src/components/Txt2Img.tsx b/gui/src/components/Txt2Img.tsx index 75fd71fb..313b4517 100644 --- a/gui/src/components/Txt2Img.tsx +++ b/gui/src/components/Txt2Img.tsx @@ -8,6 +8,7 @@ import { ConfigParams } from '../config.js'; import { ClientContext, StateContext } from '../state.js'; import { ImageControl } from './ImageControl.js'; import { NumericField } from './NumericField.js'; +import { UpscaleControl } from './UpscaleControl.js'; const { useContext } = React; @@ -75,6 +76,7 @@ export function Txt2Img(props: Txt2ImgProps) { }} /> + ; diff --git a/gui/src/components/UpscaleControl.tsx b/gui/src/components/UpscaleControl.tsx new file mode 100644 index 00000000..d8f3d7db --- /dev/null +++ b/gui/src/components/UpscaleControl.tsx @@ -0,0 +1,78 @@ +import { mustExist } from '@apextoaster/js-utils'; +import { Check, FaceRetouchingNatural, ZoomIn } from '@mui/icons-material'; +import { Stack, ToggleButton } from '@mui/material'; +import * as React from 'react'; +import { useContext } from 'react'; +import { useStore } from 'zustand'; + +import { ConfigParams } from '../config.js'; +import { StateContext } from '../state.js'; +import { NumericField } from './NumericField.js'; + +export interface UpscaleControlProps { + config: ConfigParams; +} + +export function UpscaleControl(props: UpscaleControlProps) { + const { config } = props; + + const state = mustExist(useContext(StateContext)); + const params = useStore(state, (s) => s.upscale); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setUpscale = useStore(state, (s) => s.setUpscale); + + return + { + setUpscale({ + enabled: params.enabled === false, + }); + }} + > + + Upscale + + { + setUpscale({ + scale, + }); + }} + /> + { + setUpscale({ + denoise, + }); + }} + /> + { + setUpscale({ + faces: params.faces === false, + }); + }} + > + + Face Correction + + ; +} diff --git a/gui/src/config.ts b/gui/src/config.ts index 7e3f96ba..e2610c6f 100644 --- a/gui/src/config.ts +++ b/gui/src/config.ts @@ -1,6 +1,6 @@ import { Maybe } from '@apextoaster/js-utils'; -import { Img2ImgParams, STATUS_SUCCESS, Txt2ImgParams } from './api/client.js'; +import { Img2ImgParams, InpaintParams, OutpaintParams, STATUS_SUCCESS, Txt2ImgParams, UpscaleParams } from './api/client.js'; export interface ConfigNumber { default: number; @@ -30,7 +30,7 @@ export type ConfigState = { [K in KeyFilter]: T[K] extends TValid ? T[K] : never; }; -export type ConfigParams = ConfigRanges>; +export type ConfigParams = ConfigRanges>; export interface Config { api: { diff --git a/gui/src/main.tsx b/gui/src/main.tsx index 1acacd5b..02c2eca2 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -43,22 +43,24 @@ export async function main() { // prep zustand with a slice for each tab, using local storage const { + createBrushSlice, createDefaultSlice, createHistorySlice, createImg2ImgSlice, createInpaintSlice, - createTxt2ImgSlice, - createBrushSlice, createOutpaintSlice, + createTxt2ImgSlice, + createUpscaleSlice, } = createStateSlices(params); const state = createStore(persist((...slice) => ({ - ...createTxt2ImgSlice(...slice), + ...createBrushSlice(...slice), + ...createDefaultSlice(...slice), + ...createHistorySlice(...slice), ...createImg2ImgSlice(...slice), ...createInpaintSlice(...slice), - ...createHistorySlice(...slice), - ...createDefaultSlice(...slice), - ...createBrushSlice(...slice), + ...createTxt2ImgSlice(...slice), ...createOutpaintSlice(...slice), + ...createUpscaleSlice(...slice), }), { name: 'onnx-web', partialize(s) { diff --git a/gui/src/state.ts b/gui/src/state.ts index a77d8ab3..7f0073cd 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -13,6 +13,7 @@ import { OutpaintPixels, paramsFromConfig, Txt2ImgParams, + UpscaleParams, } from './api/client.js'; import { ConfigFiles, ConfigParams, ConfigState } from './config.js'; @@ -68,7 +69,21 @@ interface BrushSlice { setBrush(brush: Partial): void; } -export type OnnxState = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & HistorySlice & DefaultSlice & OutpaintSlice & BrushSlice; +interface UpscaleSlice { + upscale: UpscaleParams; + + setUpscale(upscale: Partial): void; +} + +export type OnnxState + = BrushSlice + & DefaultSlice + & HistorySlice + & Img2ImgSlice + & InpaintSlice + & OutpaintSlice + & Txt2ImgSlice + & UpscaleSlice; export function createStateSlices(base: ConfigParams) { const defaults = paramsFromConfig(base); @@ -220,6 +235,23 @@ export function createStateSlices(base: ConfigParams) { }, }); + const createUpscaleSlice: StateCreator = (set) => ({ + upscale: { + denoise: 0.5, + enabled: false, + faces: false, + scale: 1, + }, + setUpscale(upscale) { + set((prev) => ({ + upscale: { + ...prev.upscale, + ...upscale, + } + })); + }, + }); + const createDefaultSlice: StateCreator = (set) => ({ defaults: { ...defaults, @@ -235,13 +267,14 @@ export function createStateSlices(base: ConfigParams) { }); return { + createBrushSlice, createDefaultSlice, createHistorySlice, createImg2ImgSlice, createInpaintSlice, - createTxt2ImgSlice, createOutpaintSlice, - createBrushSlice, + createTxt2ImgSlice, + createUpscaleSlice, }; }