1
0
Fork 0

fix(gui): make prompt input perform better with large LoRA/wildcard lists

This commit is contained in:
Sean Sube 2023-12-16 23:19:10 -06:00
parent a65e0fd602
commit e0929ba870
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 59 additions and 33 deletions

View File

@ -10,7 +10,7 @@ import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js';
import { ModelResponse, NetworkModel } from '../../types/api.js';
const { useContext, useMemo } = React;
@ -27,31 +27,19 @@ export interface PromptInputProps {
onChange(value: PromptValue): void;
}
export interface PromptTextBlockProps extends PromptInputProps {
models: Maybe<ModelResponse>;
}
export const PROMPT_GROUP = 75;
export function PromptInput(props: PromptInputProps) {
export function PromptTextBlock(props: PromptTextBlockProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { selector, onChange } = props;
const store = mustExist(useContext(StateContext));
const { prompt, negativePrompt } = useStore(store, selector, shallow);
const client = mustExist(useContext(ClientContext));
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
const wildcards = useQuery(['wildcards'], async () => client.wildcards(), {
staleTime: STALE_TIME,
});
const { models, selector, onChange } = props;
const { t } = useTranslation();
function addNetwork(type: string, name: string, weight = 1.0) {
onChange({
prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`,
negativePrompt,
});
}
const store = mustExist(useContext(StateContext));
const { prompt, negativePrompt } = useStore(store, selector, shallow);
function addToken(name: string) {
onChange({
@ -59,16 +47,10 @@ export function PromptInput(props: PromptInputProps) {
});
}
function addWildcard(name: string) {
onChange({
prompt: `${prompt}, __${name}__`,
});
}
const tokens = useMemo(() => {
const networks = extractNetworks(prompt);
return getNetworkTokens(models.data, networks);
}, [prompt, models.data]);
return getNetworkTokens(models, networks);
}, [models, prompt]);
return <Stack spacing={2}>
<TextField
@ -76,7 +58,7 @@ export function PromptInput(props: PromptInputProps) {
variant='outlined'
value={prompt}
onChange={(event) => {
props.onChange({
onChange({
prompt: event.target.value,
negativePrompt,
});
@ -100,6 +82,46 @@ export function PromptInput(props: PromptInputProps) {
});
}}
/>
</Stack>;
}
export function PromptInput(props: PromptInputProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { selector, onChange } = props;
const store = mustExist(useContext(StateContext));
const client = mustExist(useContext(ClientContext));
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
const wildcards = useQuery(['wildcards'], async () => client.wildcards(), {
staleTime: STALE_TIME,
});
const { t } = useTranslation();
function addNetwork(type: string, name: string, weight = 1.0) {
const { prompt, negativePrompt } = selector(store.getState());
onChange({
negativePrompt,
prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`,
});
}
function addWildcard(name: string) {
const { prompt, negativePrompt } = selector(store.getState());
onChange({
negativePrompt,
prompt: `${prompt}, __${name}__`,
});
}
return <Stack spacing={2}>
<PromptTextBlock
models={models.data}
onChange={onChange}
selector={selector}
/>
<Stack direction='row' spacing={2}>
<QueryMenu
id='inversion'
@ -107,7 +129,7 @@ export function PromptInput(props: PromptInputProps) {
name={t('modelType.inversion')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
selector: (result) => filterNetworks(result.networks, 'inversion'),
}}
onSelect={(name) => {
addNetwork('inversion', name);
@ -119,7 +141,7 @@ export function PromptInput(props: PromptInputProps) {
name={t('modelType.lora')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
selector: (result) => filterNetworks(result.networks, 'lora'),
}}
onSelect={(name) => {
addNetwork('lora', name);
@ -141,6 +163,10 @@ export function PromptInput(props: PromptInputProps) {
</Stack>;
}
export function filterNetworks(networks: Array<NetworkModel>, type: string): Array<string> {
return networks.filter((network) => network.type === type).map((network) => network.name);
}
export const ANY_TOKEN = /<([^>]+)>/g;
export type TokenList = Array<[string, number]>;