1
0
Fork 0

fix(gui): dedupe and sort available prompt tokens

This commit is contained in:
Sean Sube 2023-11-12 21:14:13 -06:00
parent 95e2d6d710
commit 35171e6f12
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 13 additions and 11 deletions

View File

@ -12,7 +12,7 @@ import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js';
const { useContext } = React;
const { useContext, useMemo } = React;
/**
* @todo replace with a selector
@ -64,8 +64,10 @@ export function PromptInput(props: PromptInputProps) {
});
}
const tokens = useMemo(() => {
const networks = extractNetworks(prompt);
const tokens = getNetworkTokens(models.data, networks);
return getNetworkTokens(models.data, networks);
}, [prompt, models.data]);
return <Stack spacing={2}>
<TextField
@ -80,7 +82,7 @@ export function PromptInput(props: PromptInputProps) {
}}
/>
<Stack direction='row' spacing={2}>
{tokens.map(([token, _weight]) => <Chip
{tokens.map((token) => <Chip
color={prompt.includes(token) ? 'primary' : 'default'}
label={token}
onClick={() => addToken(token)}
@ -162,26 +164,26 @@ export function extractNetworks(prompt: string): PromptNetworks {
}
// eslint-disable-next-line sonarjs/cognitive-complexity
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
const tokens: TokenList = [];
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): Array<string> {
const tokens: Set<string> = new Set();
if (doesExist(models)) {
for (const [name, weight] of networks.inversion) {
for (const [name, _weight] of networks.inversion) {
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
if (doesExist(model) && model.type === 'inversion') {
tokens.push([model.token, weight]);
tokens.add(model.token);
}
}
for (const [name, weight] of networks.lora) {
for (const [name, _weight] of networks.lora) {
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
if (doesExist(model) && model.type === 'lora') {
for (const token of mustDefault(model.tokens, [])) {
tokens.push([token, weight]);
tokens.add(token);
}
}
}
}
return tokens;
return Array.from(tokens).sort();
}