fix(gui): make prompt input perform better with large LoRA/wildcard lists
This commit is contained in:
parent
a65e0fd602
commit
e0929ba870
|
@ -10,7 +10,7 @@ import { shallow } from 'zustand/shallow';
|
|||
import { STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { QueryMenu } from '../input/QueryMenu.js';
|
||||
import { ModelResponse } from '../../types/api.js';
|
||||
import { ModelResponse, NetworkModel } from '../../types/api.js';
|
||||
|
||||
const { useContext, useMemo } = React;
|
||||
|
||||
|
@ -27,31 +27,19 @@ export interface PromptInputProps {
|
|||
onChange(value: PromptValue): void;
|
||||
}
|
||||
|
||||
export interface PromptTextBlockProps extends PromptInputProps {
|
||||
models: Maybe<ModelResponse>;
|
||||
}
|
||||
|
||||
export const PROMPT_GROUP = 75;
|
||||
|
||||
export function PromptInput(props: PromptInputProps) {
|
||||
export function PromptTextBlock(props: PromptTextBlockProps) {
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const { selector, onChange } = props;
|
||||
|
||||
const store = mustExist(useContext(StateContext));
|
||||
const { prompt, negativePrompt } = useStore(store, selector, shallow);
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
const models = useQuery(['models'], async () => client.models(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
const wildcards = useQuery(['wildcards'], async () => client.wildcards(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
const { models, selector, onChange } = props;
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
function addNetwork(type: string, name: string, weight = 1.0) {
|
||||
onChange({
|
||||
prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`,
|
||||
negativePrompt,
|
||||
});
|
||||
}
|
||||
const store = mustExist(useContext(StateContext));
|
||||
const { prompt, negativePrompt } = useStore(store, selector, shallow);
|
||||
|
||||
function addToken(name: string) {
|
||||
onChange({
|
||||
|
@ -59,16 +47,10 @@ export function PromptInput(props: PromptInputProps) {
|
|||
});
|
||||
}
|
||||
|
||||
function addWildcard(name: string) {
|
||||
onChange({
|
||||
prompt: `${prompt}, __${name}__`,
|
||||
});
|
||||
}
|
||||
|
||||
const tokens = useMemo(() => {
|
||||
const networks = extractNetworks(prompt);
|
||||
return getNetworkTokens(models.data, networks);
|
||||
}, [prompt, models.data]);
|
||||
return getNetworkTokens(models, networks);
|
||||
}, [models, prompt]);
|
||||
|
||||
return <Stack spacing={2}>
|
||||
<TextField
|
||||
|
@ -76,7 +58,7 @@ export function PromptInput(props: PromptInputProps) {
|
|||
variant='outlined'
|
||||
value={prompt}
|
||||
onChange={(event) => {
|
||||
props.onChange({
|
||||
onChange({
|
||||
prompt: event.target.value,
|
||||
negativePrompt,
|
||||
});
|
||||
|
@ -100,6 +82,46 @@ export function PromptInput(props: PromptInputProps) {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
</Stack>;
|
||||
}
|
||||
|
||||
export function PromptInput(props: PromptInputProps) {
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const { selector, onChange } = props;
|
||||
|
||||
const store = mustExist(useContext(StateContext));
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
const models = useQuery(['models'], async () => client.models(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
const wildcards = useQuery(['wildcards'], async () => client.wildcards(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
function addNetwork(type: string, name: string, weight = 1.0) {
|
||||
const { prompt, negativePrompt } = selector(store.getState());
|
||||
onChange({
|
||||
negativePrompt,
|
||||
prompt: `<${type}:${name}:${weight.toFixed(2)}> ${prompt}`,
|
||||
});
|
||||
}
|
||||
|
||||
function addWildcard(name: string) {
|
||||
const { prompt, negativePrompt } = selector(store.getState());
|
||||
onChange({
|
||||
negativePrompt,
|
||||
prompt: `${prompt}, __${name}__`,
|
||||
});
|
||||
}
|
||||
|
||||
return <Stack spacing={2}>
|
||||
<PromptTextBlock
|
||||
models={models.data}
|
||||
onChange={onChange}
|
||||
selector={selector}
|
||||
/>
|
||||
<Stack direction='row' spacing={2}>
|
||||
<QueryMenu
|
||||
id='inversion'
|
||||
|
@ -107,7 +129,7 @@ export function PromptInput(props: PromptInputProps) {
|
|||
name={t('modelType.inversion')}
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
||||
selector: (result) => filterNetworks(result.networks, 'inversion'),
|
||||
}}
|
||||
onSelect={(name) => {
|
||||
addNetwork('inversion', name);
|
||||
|
@ -119,7 +141,7 @@ export function PromptInput(props: PromptInputProps) {
|
|||
name={t('modelType.lora')}
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
||||
selector: (result) => filterNetworks(result.networks, 'lora'),
|
||||
}}
|
||||
onSelect={(name) => {
|
||||
addNetwork('lora', name);
|
||||
|
@ -141,6 +163,10 @@ export function PromptInput(props: PromptInputProps) {
|
|||
</Stack>;
|
||||
}
|
||||
|
||||
export function filterNetworks(networks: Array<NetworkModel>, type: string): Array<string> {
|
||||
return networks.filter((network) => network.type === type).map((network) => network.name);
|
||||
}
|
||||
|
||||
export const ANY_TOKEN = /<([^>]+)>/g;
|
||||
|
||||
export type TokenList = Array<[string, number]>;
|
||||
|
|
Loading…
Reference in New Issue