From d19bbfc1d39c40915b56adc1bfe1ec62d975eb47 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 28 Mar 2023 17:51:40 -0500 Subject: [PATCH] fix(gui): add prompt tokens to correct tab (#296) --- gui/src/components/OnnxWeb.tsx | 25 +----------- gui/src/components/control/ModelControl.tsx | 43 +++++++++++++++------ gui/src/components/utils.ts | 19 +++++++++ gui/src/utils.ts | 8 ++++ 4 files changed, 60 insertions(+), 35 deletions(-) create mode 100644 gui/src/components/utils.ts diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx index 53eb237a..6eee4329 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -13,32 +13,11 @@ import { Inpaint } from './tab/Inpaint.js'; import { Settings } from './tab/Settings.js'; import { Txt2Img } from './tab/Txt2Img.js'; import { Upscale } from './tab/Upscale.js'; - -const REMOVE_HASH = /^#?(.*)$/; -const TAB_LABELS = [ - 'txt2img', - 'img2img', - 'inpaint', - 'upscale', - 'blend', - 'settings', -]; +import { getTab, TAB_LABELS } from './utils.js'; export function OnnxWeb() { const [hash, setHash] = useHash(); - function tab(): string { - const match = hash.match(REMOVE_HASH); - if (doesExist(match)) { - const [_full, route] = Array.from(match); - if (route.length > 0) { - return route; - } - } - - return TAB_LABELS[0]; - } - return ( @@ -47,7 +26,7 @@ export function OnnxWeb() { - + { setHash(idx); diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index 01ec3181..05dd412d 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -4,12 +4,14 @@ import * as React from 'react'; import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useQuery } from 'react-query'; +import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; import { STALE_TIME } from '../../config.js'; import { ClientContext, StateContext } from '../../state.js'; import { QueryList } from '../input/QueryList.js'; import { QueryMenu } from '../input/QueryMenu.js'; +import { getTab } from '../utils.js'; export function ModelControl() { const client = mustExist(useContext(ClientContext)); @@ -26,6 +28,33 @@ export function ModelControl() { staleTime: STALE_TIME, }); + const [hash, _setHash] = useHash(); + + function addToken(type: string, name: string, weight = 1.0) { + const tab = getTab(hash); + const current = state.getState(); + + + switch (tab) { + case 'txt2img': { + const { prompt } = current.txt2img; + current.setTxt2Img({ + prompt: `<${type}:${name}:1.0> ${prompt}`, + }); + break; + } + case 'img2img': { + const { prompt } = current.img2img; + current.setImg2Img({ + prompt: `<${type}:${name}:1.0> ${prompt}`, + }); + break; + } + default: + // not supported yet + } + } + return result.networks.filter((network) => network.type === 'inversion').map((network) => network.name), }} onSelect={(name) => { - const current = state.getState(); - const { prompt } = current.txt2img; - - current.setTxt2Img({ - prompt: ` ${prompt}`, - }); + addToken('inversion', name); }} /> result.networks.filter((network) => network.type === 'lora').map((network) => network.name), }} onSelect={(name) => { - const current = state.getState(); - const { prompt } = current.txt2img; - - current.setTxt2Img({ - prompt: ` ${prompt}`, - }); + addToken('lora', name); }} /> diff --git a/gui/src/components/utils.ts b/gui/src/components/utils.ts new file mode 100644 index 00000000..db59ffbd --- /dev/null +++ b/gui/src/components/utils.ts @@ -0,0 +1,19 @@ +import { trimHash } from '../utils.js'; + +export const TAB_LABELS = [ + 'txt2img', + 'img2img', + 'inpaint', + 'upscale', + 'blend', + 'settings', +] as const; + +export function getTab(hash: string): string { + const route = trimHash(hash); + if (route.length > 0) { + return route; + } + + return TAB_LABELS[0]; +} diff --git a/gui/src/utils.ts b/gui/src/utils.ts index a209721c..6fbdbab9 100644 --- a/gui/src/utils.ts +++ b/gui/src/utils.ts @@ -18,3 +18,11 @@ export function range(max: number): Array { export function visibleIndex(idx: number): string { return (idx + 1).toFixed(0); } + +export function trimHash(val: string): string { + if (val[0] === '#') { + return val.slice(1); + } + + return val; +}