1
0
Fork 0

fix(gui): dedupe and sort available prompt tokens

This commit is contained in:
Sean Sube 2023-11-12 21:14:13 -06:00
parent 95e2d6d710
commit 35171e6f12
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 13 additions and 11 deletions

View File

@ -12,7 +12,7 @@ import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { QueryMenu } from '../input/QueryMenu.js'; import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js'; import { ModelResponse } from '../../types/api.js';
const { useContext } = React; const { useContext, useMemo } = React;
/** /**
* @todo replace with a selector * @todo replace with a selector
@ -64,8 +64,10 @@ export function PromptInput(props: PromptInputProps) {
}); });
} }
const networks = extractNetworks(prompt); const tokens = useMemo(() => {
const tokens = getNetworkTokens(models.data, networks); const networks = extractNetworks(prompt);
return getNetworkTokens(models.data, networks);
}, [prompt, models.data]);
return <Stack spacing={2}> return <Stack spacing={2}>
<TextField <TextField
@ -80,7 +82,7 @@ export function PromptInput(props: PromptInputProps) {
}} }}
/> />
<Stack direction='row' spacing={2}> <Stack direction='row' spacing={2}>
{tokens.map(([token, _weight]) => <Chip {tokens.map((token) => <Chip
color={prompt.includes(token) ? 'primary' : 'default'} color={prompt.includes(token) ? 'primary' : 'default'}
label={token} label={token}
onClick={() => addToken(token)} onClick={() => addToken(token)}
@ -162,26 +164,26 @@ export function extractNetworks(prompt: string): PromptNetworks {
} }
// eslint-disable-next-line sonarjs/cognitive-complexity // eslint-disable-next-line sonarjs/cognitive-complexity
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList { export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): Array<string> {
const tokens: TokenList = []; const tokens: Set<string> = new Set();
if (doesExist(models)) { if (doesExist(models)) {
for (const [name, weight] of networks.inversion) { for (const [name, _weight] of networks.inversion) {
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name); const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
if (doesExist(model) && model.type === 'inversion') { if (doesExist(model) && model.type === 'inversion') {
tokens.push([model.token, weight]); tokens.add(model.token);
} }
} }
for (const [name, weight] of networks.lora) { for (const [name, _weight] of networks.lora) {
const model = models.networks.find((it) => it.type === 'lora' && it.name === name); const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
if (doesExist(model) && model.type === 'lora') { if (doesExist(model) && model.type === 'lora') {
for (const token of mustDefault(model.tokens, [])) { for (const token of mustDefault(model.tokens, [])) {
tokens.push([token, weight]); tokens.add(token);
} }
} }
} }
} }
return tokens; return Array.from(tokens).sort();
} }