From 35e2e1dda643f5d44dd6a218535a4b85bc28c2e5 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 12 Jan 2023 00:10:57 -0600 Subject: [PATCH] fix(gui): improve performance while using image controls --- gui/esbuild.js | 9 ++++ gui/package.json | 1 + gui/src/components/ImageHistory.tsx | 15 +++--- gui/src/components/Img2Img.tsx | 26 ++++++---- gui/src/components/Inpaint.tsx | 22 +++++--- gui/src/components/Txt2Img.tsx | 30 +++++++---- gui/src/main.tsx | 78 ++++++++++++++--------------- gui/src/state.ts | 45 +++++++++++++++++ gui/yarn.lock | 5 ++ 9 files changed, 159 insertions(+), 72 deletions(-) create mode 100644 gui/src/state.ts diff --git a/gui/esbuild.js b/gui/esbuild.js index fdae8508..f3060d4e 100644 --- a/gui/esbuild.js +++ b/gui/esbuild.js @@ -1,6 +1,9 @@ import { build } from 'esbuild'; +import { createRequire } from 'node:module'; import { join } from 'path'; +import alias from 'esbuild-plugin-alias'; +const require = createRequire(import.meta.url); const root = process.cwd(); build({ @@ -14,5 +17,11 @@ build({ keepNames: true, outdir: 'out/bundle/', platform: 'browser', + plugins: [ + alias({ + 'react-dom$': 'react-dom/profiling', + 'scheduler/tracing': 'scheduler/tracing-profiling', + }) + ], sourcemap: true, }).catch(() => process.exit(1)); diff --git a/gui/package.json b/gui/package.json index 29da79b2..8887c719 100644 --- a/gui/package.json +++ b/gui/package.json @@ -36,6 +36,7 @@ "chai": "^4.3.7", "chai-as-promised": "^7.1.1", "esbuild": "^0.16.14", + "esbuild-plugin-alias": "^0.2.1", "eslint": "^8.31.0", "eslint-plugin-chai": "^0.0.1", "eslint-plugin-chai-expect": "^3.0.0", diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx index ff27a677..443dfebd 100644 --- a/gui/src/components/ImageHistory.tsx +++ b/gui/src/components/ImageHistory.tsx @@ -10,28 +10,31 @@ import { ImageCard } from './ImageCard.js'; import { LoadingCard } from './LoadingCard.js'; export function ImageHistory() { - const state = useStore(mustExist(useContext(StateContext))); - const { images } = state.history; + const history = useStore(mustExist(useContext(StateContext)), (state) => state.history); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setHistory = useStore(mustExist(useContext(StateContext)), (state) => state.setHistory); + + const { images } = history; const children = []; - if (state.history.loading) { + if (history.loading) { children.push(); // TODO: get dimensions from config } function removeHistory(image: ApiResponse) { - state.setHistory(images.filter((item) => image.output !== item.output)); + setHistory(images.filter((item) => image.output !== item.output)); } if (images.length > 0) { children.push(...images.map((item) => )); } else { - if (state.history.loading === false) { + if (history.loading === false) { children.push(
No results. Press Generate.
); } } - const limited = children.slice(0, state.history.limit); + const limited = children.slice(0, history.limit); return {limited.map((child, idx) => {child})}; } diff --git a/gui/src/components/Img2Img.tsx b/gui/src/components/Img2Img.tsx index d8fa2187..d0405245 100644 --- a/gui/src/components/Img2Img.tsx +++ b/gui/src/components/Img2Img.tsx @@ -23,30 +23,38 @@ export function Img2Img(props: Img2ImgProps) { const { config, model, platform } = props; async function uploadSource() { - state.setLoading(true); + setLoading(true); const output = await client.img2img({ - ...state.img2img, + ...params, model, platform, source: mustExist(source), // TODO: show an error if this doesn't exist }); - state.pushHistory(output); - state.setLoading(false); + pushHistory(output); + setLoading(false); } const client = mustExist(useContext(ClientContext)); const upload = useMutation(uploadSource); - const state = useStore(mustExist(useContext(StateContext))); + + const state = mustExist(useContext(StateContext)); + const params = useStore(state, (s) => s.img2img); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setImg2Img = useStore(state, (s) => s.setImg2Img); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setLoading = useStore(state, (s) => s.setLoading); + // eslint-disable-next-line @typescript-eslint/unbound-method + const pushHistory = useStore(state, (s) => s.pushHistory); const [source, setSource] = useState(); return - { - state.setImg2Img(newParams); + { + setImg2Img(newParams); }} /> { - state.setImg2Img({ + setImg2Img({ strength: value, }); }} diff --git a/gui/src/components/Inpaint.tsx b/gui/src/components/Inpaint.tsx index 8443e0c8..de175e87 100644 --- a/gui/src/components/Inpaint.tsx +++ b/gui/src/components/Inpaint.tsx @@ -70,18 +70,18 @@ export function Inpaint(props: InpaintProps) { async function uploadSource() { const canvas = mustExist(canvasRef.current); - state.setLoading(true); + setLoading(true); return new Promise((res, rej) => { canvas.toBlob((blob) => { client.inpaint({ - ...state.inpaint, + ...params, model, platform, mask: mustExist(blob), source: mustExist(source), }).then((output) => { - state.pushHistory(output); - state.setLoading(false); + pushHistory(output); + setLoading(false); res(); }).catch((err) => rej(err)); }); @@ -138,10 +138,18 @@ export function Inpaint(props: InpaintProps) { ctx.putImageData(image, 0, 0); } + const state = mustExist(useContext(StateContext)); + const params = useStore(state, (s) => s.inpaint); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setInpaint = useStore(state, (s) => s.setInpaint); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setLoading = useStore(state, (s) => s.setLoading); + // eslint-disable-next-line @typescript-eslint/unbound-method + const pushHistory = useStore(state, (s) => s.pushHistory); + const upload = useMutation(uploadSource); // eslint-disable-next-line no-null/no-null const canvasRef = useRef(null); - const state = useStore(mustExist(useContext(StateContext))); // painting state const [clicks, setClicks] = useState>([]); @@ -259,9 +267,9 @@ export function Inpaint(props: InpaintProps) { { - state.setInpaint(newParams); + setInpaint(newParams); }} /> diff --git a/gui/src/components/Txt2Img.tsx b/gui/src/components/Txt2Img.tsx index b21bbc13..ca358901 100644 --- a/gui/src/components/Txt2Img.tsx +++ b/gui/src/components/Txt2Img.tsx @@ -22,26 +22,34 @@ export function Txt2Img(props: Txt2ImgProps) { const { config, model, platform } = props; async function generateImage() { - state.setLoading(true); + setLoading(true); const output = await client.txt2img({ - ...state.txt2img, + ...params, model, platform, }); - state.pushHistory(output); - state.setLoading(false); + pushHistory(output); + setLoading(false); } const client = mustExist(useContext(ClientContext)); const generate = useMutation(generateImage); - const state = useStore(mustExist(useContext(StateContext))); + + const state = mustExist(useContext(StateContext)); + const params = useStore(state, (s) => s.txt2img); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setTxt2Img = useStore(state, (s) => s.setTxt2Img); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setLoading = useStore(state, (s) => s.setLoading); + // eslint-disable-next-line @typescript-eslint/unbound-method + const pushHistory = useStore(state, (s) => s.pushHistory); return - { - state.setTxt2Img(newParams); + { + setTxt2Img(newParams); }} /> { - state.setTxt2Img({ + setTxt2Img({ width: value, }); }} @@ -61,9 +69,9 @@ export function Txt2Img(props: Txt2ImgProps) { min={config.height.min} max={config.height.max} step={config.height.step} - value={state.txt2img.height} + value={params.height} onChange={(value) => { - state.setTxt2Img({ + setTxt2Img({ height: value, }); }} diff --git a/gui/src/main.tsx b/gui/src/main.tsx index 292e7a1a..e3bf5118 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -66,45 +66,6 @@ export async function main() { inpaint: { ...defaults, }, - setLimit(limit) { - set((oldState) => ({ - ...oldState, - history: { - ...oldState.history, - limit, - }, - })); - }, - setLoading(loading) { - set((oldState) => ({ - ...oldState, - history: { - ...oldState.history, - loading, - }, - })); - }, - pushHistory(newImage: ApiResponse) { - set((oldState) => ({ - ...oldState, - history: { - ...oldState.history, - images: [ - newImage, - ...oldState.history.images, - ].slice(0, oldState.history.limit), - }, - })); - }, - setHistory(newHistory: Array) { - set((oldState) => ({ - ...oldState, - history: { - ...oldState.history, - images: newHistory, - }, - })); - }, setDefaults(newParams) { set((oldState) => ({ ...oldState, @@ -168,6 +129,45 @@ export async function main() { }, })); }, + setLimit(limit) { + set((oldState) => ({ + ...oldState, + history: { + ...oldState.history, + limit, + }, + })); + }, + setLoading(loading) { + set((oldState) => ({ + ...oldState, + history: { + ...oldState.history, + loading, + }, + })); + }, + pushHistory(newImage: ApiResponse) { + set((oldState) => ({ + ...oldState, + history: { + ...oldState.history, + images: [ + newImage, + ...oldState.history.images, + ].slice(0, oldState.history.limit), + }, + })); + }, + setHistory(newHistory: Array) { + set((oldState) => ({ + ...oldState, + history: { + ...oldState.history, + images: newHistory, + }, + })); + }, }), { name: 'onnx-web', partialize: (oldState) => ({ diff --git a/gui/src/state.ts b/gui/src/state.ts new file mode 100644 index 00000000..7ef2fe4f --- /dev/null +++ b/gui/src/state.ts @@ -0,0 +1,45 @@ +import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, Txt2ImgParams } from './api/client.js'; +import { ConfigState } from './config.js'; + +interface TabState { + params: ConfigState>; + + reset(): void; + update(params: Partial>>): void; +} + +interface OnnxState { + defaults: { + params: Required; + update(newParams: Partial): void; + }; + txt2img: { + params: ConfigState>; + + reset(): void; + update(newParams: Partial>>): void; + }; + img2img: { + params: ConfigState>; + + reset(): void; + update(newParams: Partial>>): void; + }; + inpaint: { + params: ConfigState>; + + reset(): void; + update(newParams: Partial>>): void; + }; + history: { + images: Array; + limit: number; + loading: boolean; + + setLimit(limit: number): void; + setLoading(loading: boolean): void; + setHistory(newHistory: Array): void; + pushHistory(newImage: ApiResponse): void; + }; +} + diff --git a/gui/yarn.lock b/gui/yarn.lock index 987b12f7..7798537c 100644 --- a/gui/yarn.lock +++ b/gui/yarn.lock @@ -1212,6 +1212,11 @@ es-to-primitive@^1.2.1: is-date-object "^1.0.1" is-symbol "^1.0.2" +esbuild-plugin-alias@^0.2.1: + version "0.2.1" + resolved "https://registry.yarnpkg.com/esbuild-plugin-alias/-/esbuild-plugin-alias-0.2.1.tgz#45a86cb941e20e7c2bc68a2bea53562172494fcb" + integrity sha512-jyfL/pwPqaFXyKnj8lP8iLk6Z0m099uXR45aSN8Av1XD4vhvQutxxPzgA2bTcAwQpa1zCXDcWOlhFgyP3GKqhQ== + esbuild@^0.16.14: version "0.16.14" resolved "https://registry.yarnpkg.com/esbuild/-/esbuild-0.16.14.tgz#366249a0a0fd431d3ab706195721ef1014198919"