fix(gui): dedupe and sort available prompt tokens
This commit is contained in:
parent
95e2d6d710
commit
35171e6f12
|
@ -12,7 +12,7 @@ import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
||||||
import { QueryMenu } from '../input/QueryMenu.js';
|
import { QueryMenu } from '../input/QueryMenu.js';
|
||||||
import { ModelResponse } from '../../types/api.js';
|
import { ModelResponse } from '../../types/api.js';
|
||||||
|
|
||||||
const { useContext } = React;
|
const { useContext, useMemo } = React;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @todo replace with a selector
|
* @todo replace with a selector
|
||||||
|
@ -64,8 +64,10 @@ export function PromptInput(props: PromptInputProps) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const tokens = useMemo(() => {
|
||||||
const networks = extractNetworks(prompt);
|
const networks = extractNetworks(prompt);
|
||||||
const tokens = getNetworkTokens(models.data, networks);
|
return getNetworkTokens(models.data, networks);
|
||||||
|
}, [prompt, models.data]);
|
||||||
|
|
||||||
return <Stack spacing={2}>
|
return <Stack spacing={2}>
|
||||||
<TextField
|
<TextField
|
||||||
|
@ -80,7 +82,7 @@ export function PromptInput(props: PromptInputProps) {
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<Stack direction='row' spacing={2}>
|
<Stack direction='row' spacing={2}>
|
||||||
{tokens.map(([token, _weight]) => <Chip
|
{tokens.map((token) => <Chip
|
||||||
color={prompt.includes(token) ? 'primary' : 'default'}
|
color={prompt.includes(token) ? 'primary' : 'default'}
|
||||||
label={token}
|
label={token}
|
||||||
onClick={() => addToken(token)}
|
onClick={() => addToken(token)}
|
||||||
|
@ -162,26 +164,26 @@ export function extractNetworks(prompt: string): PromptNetworks {
|
||||||
}
|
}
|
||||||
|
|
||||||
// eslint-disable-next-line sonarjs/cognitive-complexity
|
// eslint-disable-next-line sonarjs/cognitive-complexity
|
||||||
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
|
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): Array<string> {
|
||||||
const tokens: TokenList = [];
|
const tokens: Set<string> = new Set();
|
||||||
|
|
||||||
if (doesExist(models)) {
|
if (doesExist(models)) {
|
||||||
for (const [name, weight] of networks.inversion) {
|
for (const [name, _weight] of networks.inversion) {
|
||||||
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
|
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
|
||||||
if (doesExist(model) && model.type === 'inversion') {
|
if (doesExist(model) && model.type === 'inversion') {
|
||||||
tokens.push([model.token, weight]);
|
tokens.add(model.token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const [name, weight] of networks.lora) {
|
for (const [name, _weight] of networks.lora) {
|
||||||
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
|
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
|
||||||
if (doesExist(model) && model.type === 'lora') {
|
if (doesExist(model) && model.type === 'lora') {
|
||||||
for (const token of mustDefault(model.tokens, [])) {
|
for (const token of mustDefault(model.tokens, [])) {
|
||||||
tokens.push([token, weight]);
|
tokens.add(token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokens;
|
return Array.from(tokens).sort();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue