1
0
Fork 0

feat: show tokens for networks in prompt

This commit is contained in:
Sean Sube 2023-11-12 15:15:06 -06:00
parent 3ffbc00390
commit 44e483322e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 140 additions and 64 deletions

View File

@ -379,6 +379,9 @@ def encode_prompt(
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
) -> List[np.ndarray]: ) -> List[np.ndarray]:
"""
TODO: does not work with SDXL, fix or turn into a pipeline patch
"""
return [ return [
pipe._encode_prompt( pipe._encode_prompt(
remove_tokens(prompt), remove_tokens(prompt),

View File

@ -10,34 +10,53 @@ $defs:
- type: number - type: number
- type: string - type: string
lora_network: tensor_format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
embedding_network:
type: object type: object
required: [name, source] required: [name, source]
properties: properties:
name: format:
type: string $ref: "#/defs/tensor_format"
source:
type: string
label: label:
type: string type: string
weight: model:
type: number
textual_inversion_network:
type: object
required: [name, source]
properties:
name:
type: string
source:
type: string
format:
type: string type: string
enum: [concept, embeddings] enum: [concept, embeddings]
label: name:
type: string
source:
type: string type: string
token: token:
type: string type: string
type:
type: string
const: inversion # TODO: add embedding
weight:
type: number
lora_network:
type: object
required: [name, source, type]
properties:
label:
type: string
model:
type: string
enum: [cloneofsimo, sd-scripts]
name:
type: string
source:
type: string
tokens:
type: array
items:
type: string
type:
type: string
const: lora
weight: weight:
type: number type: number
@ -46,8 +65,7 @@ $defs:
required: [name, source] required: [name, source]
properties: properties:
format: format:
type: string $ref: "#/defs/tensor_format"
enum: [bin, ckpt, onnx, pt, pth, safetensors]
half: half:
type: boolean type: boolean
label: label:
@ -85,7 +103,7 @@ $defs:
inversions: inversions:
type: array type: array
items: items:
$ref: "#/$defs/textual_inversion_network" $ref: "#/$defs/embedding_network"
loras: loras:
type: array type: array
items: items:
@ -142,31 +160,6 @@ $defs:
source: source:
type: string type: string
source_network:
type: object
required: [name, source, type]
properties:
format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
model:
type: string
enum: [
# inversion
concept,
embeddings,
# lora
cloneofsimo,
sd-scripts
]
name:
type: string
source:
type: string
type:
type: string
enum: [inversion, lora]
translation: translation:
type: object type: object
additionalProperties: False additionalProperties: False
@ -194,7 +187,9 @@ properties:
networks: networks:
type: array type: array
items: items:
$ref: "#/$defs/source_network" oneOf:
- $ref: "#/$defs/lora_network"
- $ref: "#/$defs/embedding_network"
sources: sources:
type: array type: array
items: items:

View File

@ -1,5 +1,5 @@
import { mustExist } from '@apextoaster/js-utils'; import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
import { TextField } from '@mui/material'; import { Chip, TextField } from '@mui/material';
import { Stack } from '@mui/system'; import { Stack } from '@mui/system';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import * as React from 'react'; import * as React from 'react';
@ -10,6 +10,7 @@ import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js'; import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js'; import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { QueryMenu } from '../input/QueryMenu.js'; import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js';
const { useContext } = React; const { useContext } = React;
@ -48,26 +49,27 @@ export function PromptInput(props: PromptInputProps) {
staleTime: STALE_TIME, staleTime: STALE_TIME,
}); });
const tokens = splitPrompt(prompt);
const groups = Math.ceil(tokens.length / PROMPT_GROUP);
const { t } = useTranslation(); const { t } = useTranslation();
const helper = t('input.prompt.tokens', {
groups,
tokens: tokens.length,
});
function addToken(type: string, name: string, weight = 1.0) { function addNetwork(type: string, name: string, weight = 1.0) {
onChange({ onChange({
prompt: `<${type}:${name}:1.0> ${prompt}`, prompt: `<${type}:${name}:1.0> ${prompt}`,
negativePrompt, negativePrompt,
}); });
} }
function addToken(name: string) {
onChange({
prompt: `${prompt}, ${name}`,
});
}
const networks = extractNetworks(prompt);
const tokens = getNetworkTokens(models.data, networks);
return <Stack spacing={2}> return <Stack spacing={2}>
<TextField <TextField
label={t('parameter.prompt')} label={t('parameter.prompt')}
helperText={helper}
variant='outlined' variant='outlined'
value={prompt} value={prompt}
onChange={(event) => { onChange={(event) => {
@ -77,6 +79,7 @@ export function PromptInput(props: PromptInputProps) {
}); });
}} }}
/> />
{tokens.map(([token, _weight]) => <Chip label={token} onClick={() => addToken(token)} />)}
<TextField <TextField
label={t('parameter.negativePrompt')} label={t('parameter.negativePrompt')}
variant='outlined' variant='outlined'
@ -98,7 +101,7 @@ export function PromptInput(props: PromptInputProps) {
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name), selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}} }}
onSelect={(name) => { onSelect={(name) => {
addToken('inversion', name); addNetwork('inversion', name);
}} }}
/> />
<QueryMenu <QueryMenu
@ -110,9 +113,69 @@ export function PromptInput(props: PromptInputProps) {
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name), selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}} }}
onSelect={(name) => { onSelect={(name) => {
addToken('lora', name); addNetwork('lora', name);
}} }}
/> />
</Stack> </Stack>
</Stack>; </Stack>;
} }
export const ANY_TOKEN = /<([^>])+>/g;
export type TokenList = Array<[string, number]>;
export interface PromptNetworks {
inversion: TokenList;
lora: TokenList;
}
export function extractNetworks(prompt: string): PromptNetworks {
const inversion: TokenList = [];
const lora: TokenList = [];
for (const token of prompt.matchAll(ANY_TOKEN)) {
const [_whole, match] = Array.from(token);
const [type, name, weight, ..._rest] = match.split(':');
switch (type) {
case 'inversion':
inversion.push([name, parseFloat(weight)]);
break;
case 'lora':
lora.push([name, parseFloat(weight)]);
break;
default:
// ignore others
}
}
return {
inversion,
lora,
};
}
// eslint-disable-next-line sonarjs/cognitive-complexity
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
const tokens: TokenList = [];
if (doesExist(models)) {
for (const [name, weight] of networks.inversion) {
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
if (doesExist(model) && model.type === 'inversion') {
tokens.push([model.token, weight]);
}
}
for (const [name, weight] of networks.inversion) {
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
if (doesExist(model) && model.type === 'lora') {
for (const token of model.tokens) {
tokens.push([token, weight]);
}
}
}
}
return tokens;
}

View File

@ -39,13 +39,28 @@ export interface ReadyResponse {
ready: boolean; ready: boolean;
} }
export interface NetworkModel { export interface ControlNetwork {
name: string; name: string;
type: 'control' | 'inversion' | 'lora'; type: 'control';
// TODO: add token
// TODO: add layer/token count
} }
export interface EmbeddingNetwork {
label: string;
name: string;
token: string;
type: 'inversion';
// TODO: add layer count
}
export interface LoraNetwork {
name: string;
label: string;
tokens: Array<string>;
type: 'lora';
}
export type NetworkModel = EmbeddingNetwork | LoraNetwork | ControlNetwork;
export interface FilterResponse { export interface FilterResponse {
mask: Array<string>; mask: Array<string>;
source: Array<string>; source: Array<string>;