From e0929ba87098afee294985db20a30dfe0f085793 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 16 Dec 2023 23:19:10 -0600 Subject: [PATCH] fix(gui): make prompt input perform better with large LoRA/wildcard lists --- gui/src/components/input/PromptInput.tsx | 92 +++++++++++++++--------- 1 file changed, 59 insertions(+), 33 deletions(-) diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx index a8081222..12313e89 100644 --- a/gui/src/components/input/PromptInput.tsx +++ b/gui/src/components/input/PromptInput.tsx @@ -10,7 +10,7 @@ import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; import { QueryMenu } from '../input/QueryMenu.js'; -import { ModelResponse } from '../../types/api.js'; +import { ModelResponse, NetworkModel } from '../../types/api.js'; const { useContext, useMemo } = React; @@ -27,31 +27,19 @@ export interface PromptInputProps { onChange(value: PromptValue): void; } +export interface PromptTextBlockProps extends PromptInputProps { + models: Maybe; +} + export const PROMPT_GROUP = 75; -export function PromptInput(props: PromptInputProps) { +export function PromptTextBlock(props: PromptTextBlockProps) { // eslint-disable-next-line @typescript-eslint/unbound-method - const { selector, onChange } = props; - - const store = mustExist(useContext(StateContext)); - const { prompt, negativePrompt } = useStore(store, selector, shallow); - - const client = mustExist(useContext(ClientContext)); - const models = useQuery(['models'], async () => client.models(), { - staleTime: STALE_TIME, - }); - const wildcards = useQuery(['wildcards'], async () => client.wildcards(), { - staleTime: STALE_TIME, - }); + const { models, selector, onChange } = props; const { t } = useTranslation(); - - function addNetwork(type: string, name: string, weight = 1.0) { - onChange({ - prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`, - negativePrompt, - }); - } + const store = mustExist(useContext(StateContext)); + const { prompt, negativePrompt } = useStore(store, selector, shallow); function addToken(name: string) { onChange({ @@ -59,16 +47,10 @@ export function PromptInput(props: PromptInputProps) { }); } - function addWildcard(name: string) { - onChange({ - prompt: `${prompt}, __${name}__`, - }); - } - const tokens = useMemo(() => { const networks = extractNetworks(prompt); - return getNetworkTokens(models.data, networks); - }, [prompt, models.data]); + return getNetworkTokens(models, networks); + }, [models, prompt]); return { - props.onChange({ + onChange({ prompt: event.target.value, negativePrompt, }); @@ -100,6 +82,46 @@ export function PromptInput(props: PromptInputProps) { }); }} /> + ; +} + +export function PromptInput(props: PromptInputProps) { + // eslint-disable-next-line @typescript-eslint/unbound-method + const { selector, onChange } = props; + + const store = mustExist(useContext(StateContext)); + const client = mustExist(useContext(ClientContext)); + const models = useQuery(['models'], async () => client.models(), { + staleTime: STALE_TIME, + }); + const wildcards = useQuery(['wildcards'], async () => client.wildcards(), { + staleTime: STALE_TIME, + }); + + const { t } = useTranslation(); + + function addNetwork(type: string, name: string, weight = 1.0) { + const { prompt, negativePrompt } = selector(store.getState()); + onChange({ + negativePrompt, + prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`, + }); + } + + function addWildcard(name: string) { + const { prompt, negativePrompt } = selector(store.getState()); + onChange({ + negativePrompt, + prompt: `${prompt}, __${name}__`, + }); + } + + return + result.networks.filter((network) => network.type === 'inversion').map((network) => network.name), + selector: (result) => filterNetworks(result.networks, 'inversion'), }} onSelect={(name) => { addNetwork('inversion', name); @@ -119,7 +141,7 @@ export function PromptInput(props: PromptInputProps) { name={t('modelType.lora')} query={{ result: models, - selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name), + selector: (result) => filterNetworks(result.networks, 'lora'), }} onSelect={(name) => { addNetwork('lora', name); @@ -141,6 +163,10 @@ export function PromptInput(props: PromptInputProps) { ; } +export function filterNetworks(networks: Array, type: string): Array { + return networks.filter((network) => network.type === type).map((network) => network.name); +} + export const ANY_TOKEN = /<([^>]+)>/g; export type TokenList = Array<[string, number]>; @@ -166,7 +192,7 @@ export function extractNetworks(prompt: string): PromptNetworks { lora.push([name, parseFloat(weight)]); break; default: - // ignore others + // ignore others } }