1
0
Fork 0

feat: show tokens for networks in prompt

This commit is contained in:
Sean Sube 2023-11-12 15:15:06 -06:00
parent 3ffbc00390
commit 44e483322e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 140 additions and 64 deletions

View File

@ -379,6 +379,9 @@ def encode_prompt(
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
) -> List[np.ndarray]:
"""
TODO: does not work with SDXL, fix or turn into a pipeline patch
"""
return [
pipe._encode_prompt(
remove_tokens(prompt),

View File

@ -10,34 +10,53 @@ $defs:
- type: number
- type: string
lora_network:
tensor_format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
embedding_network:
type: object
required: [name, source]
properties:
name:
type: string
source:
type: string
format:
$ref: "#/$defs/tensor_format"
label:
type: string
weight:
type: number
textual_inversion_network:
type: object
required: [name, source]
properties:
name:
type: string
source:
type: string
format:
model:
type: string
enum: [concept, embeddings]
label:
name:
type: string
source:
type: string
token:
type: string
type:
type: string
const: inversion # TODO: add embedding
weight:
type: number
lora_network:
type: object
required: [name, source, type]
properties:
label:
type: string
model:
type: string
enum: [cloneofsimo, sd-scripts]
name:
type: string
source:
type: string
tokens:
type: array
items:
type: string
type:
type: string
const: lora
weight:
type: number
@ -46,8 +65,7 @@ $defs:
required: [name, source]
properties:
format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
$ref: "#/$defs/tensor_format"
half:
type: boolean
label:
@ -85,7 +103,7 @@ $defs:
inversions:
type: array
items:
$ref: "#/$defs/textual_inversion_network"
$ref: "#/$defs/embedding_network"
loras:
type: array
items:
@ -142,31 +160,6 @@ $defs:
source:
type: string
source_network:
type: object
required: [name, source, type]
properties:
format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
model:
type: string
enum: [
# inversion
concept,
embeddings,
# lora
cloneofsimo,
sd-scripts
]
name:
type: string
source:
type: string
type:
type: string
enum: [inversion, lora]
translation:
type: object
additionalProperties: False
@ -194,7 +187,9 @@ properties:
networks:
type: array
items:
$ref: "#/$defs/source_network"
oneOf:
- $ref: "#/$defs/lora_network"
- $ref: "#/$defs/embedding_network"
sources:
type: array
items:

View File

@ -1,5 +1,5 @@
import { mustExist } from '@apextoaster/js-utils';
import { TextField } from '@mui/material';
import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
import { Chip, TextField } from '@mui/material';
import { Stack } from '@mui/system';
import { useQuery } from '@tanstack/react-query';
import * as React from 'react';
@ -10,6 +10,7 @@ import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js';
const { useContext } = React;
@ -48,26 +49,27 @@ export function PromptInput(props: PromptInputProps) {
staleTime: STALE_TIME,
});
const tokens = splitPrompt(prompt);
const groups = Math.ceil(tokens.length / PROMPT_GROUP);
const { t } = useTranslation();
const helper = t('input.prompt.tokens', {
groups,
tokens: tokens.length,
});
function addToken(type: string, name: string, weight = 1.0) {
function addNetwork(type: string, name: string, weight = 1.0) {
onChange({
prompt: `<${type}:${name}:1.0> ${prompt}`,
negativePrompt,
});
}
function addToken(name: string) {
onChange({
prompt: `${prompt}, ${name}`,
});
}
const networks = extractNetworks(prompt);
const tokens = getNetworkTokens(models.data, networks);
return <Stack spacing={2}>
<TextField
label={t('parameter.prompt')}
helperText={helper}
variant='outlined'
value={prompt}
onChange={(event) => {
@ -77,6 +79,7 @@ export function PromptInput(props: PromptInputProps) {
});
}}
/>
{tokens.map(([token, _weight]) => <Chip label={token} onClick={() => addToken(token)} />)}
<TextField
label={t('parameter.negativePrompt')}
variant='outlined'
@ -98,7 +101,7 @@ export function PromptInput(props: PromptInputProps) {
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}}
onSelect={(name) => {
addToken('inversion', name);
addNetwork('inversion', name);
}}
/>
<QueryMenu
@ -110,9 +113,69 @@ export function PromptInput(props: PromptInputProps) {
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}}
onSelect={(name) => {
addToken('lora', name);
addNetwork('lora', name);
}}
/>
</Stack>
</Stack>;
}
/**
 * Matches prompt network tokens like `<lora:name:1.0>` and captures the token
 * body (everything between the angle brackets) so it can be split on `:`.
 *
 * BUG FIX: the group must wrap the whole run, `([^>]+)`, not a repeated
 * single-character group, `([^>])+` — the latter only captures the LAST
 * character of the body, so `type`/`name`/`weight` could never be parsed.
 */
export const ANY_TOKEN = /<([^>]+)>/g;

/** Pairs of (token name, weight) parsed from or destined for a prompt. */
export type TokenList = Array<[string, number]>;

/** Networks referenced by a prompt, grouped by network type. */
export interface PromptNetworks {
  inversion: TokenList;
  lora: TokenList;
}

/**
 * Parse the Textual Inversion and LoRA network tokens out of a prompt.
 *
 * Tokens have the form `<type:name:weight>`. Tokens of any other type are
 * ignored. A missing weight defaults to 1.0 rather than producing NaN.
 */
export function extractNetworks(prompt: string): PromptNetworks {
  const inversion: TokenList = [];
  const lora: TokenList = [];

  for (const token of prompt.matchAll(ANY_TOKEN)) {
    const [_whole, match] = Array.from(token);
    const [type, name, weight, ..._rest] = match.split(':');
    // default the weight when the token omits it, e.g. `<lora:name>`
    const parsedWeight = weight === undefined ? 1.0 : parseFloat(weight);

    switch (type) {
      case 'inversion':
        inversion.push([name, parsedWeight]);
        break;
      case 'lora':
        lora.push([name, parsedWeight]);
        break;
      default:
        // ignore other network types
    }
  }

  return {
    inversion,
    lora,
  };
}
/**
 * Collect the trigger tokens for every network referenced in a prompt.
 *
 * For each inversion network, the single `token` of the matching model is
 * collected; for each LoRA network, all of the model's `tokens` are collected.
 * Networks that do not match a known model are skipped, and an empty list is
 * returned when the model data has not loaded yet.
 */
// eslint-disable-next-line sonarjs/cognitive-complexity
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
  const tokens: TokenList = [];

  if (doesExist(models)) {
    for (const [name, weight] of networks.inversion) {
      const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
      if (doesExist(model) && model.type === 'inversion') {
        tokens.push([model.token, weight]);
      }
    }

    // BUG FIX: this loop previously iterated networks.inversion, so LoRA
    // trigger tokens were looked up with inversion names and never matched.
    for (const [name, weight] of networks.lora) {
      const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
      if (doesExist(model) && model.type === 'lora') {
        for (const token of model.tokens) {
          tokens.push([token, weight]);
        }
      }
    }
  }

  return tokens;
}

View File

@ -39,13 +39,28 @@ export interface ReadyResponse {
ready: boolean;
}
export interface NetworkModel {
export interface ControlNetwork {
name: string;
type: 'control' | 'inversion' | 'lora';
// TODO: add token
// TODO: add layer/token count
type: 'control';
}
/**
 * A Textual Inversion embedding, including the single trigger token that
 * activates it when present in a prompt.
 */
export interface EmbeddingNetwork {
  label: string;
  name: string;
  token: string;
  type: 'inversion';
  // TODO: add layer count
}
/**
 * A LoRA network, including the list of trigger tokens that activate it when
 * present in a prompt.
 */
export interface LoraNetwork {
  label: string;
  name: string;
  tokens: Array<string>;
  type: 'lora';
}
/** Any extra network model known to the server: embedding, LoRA, or ControlNet. */
export type NetworkModel = EmbeddingNetwork | LoraNetwork | ControlNetwork;
export interface FilterResponse {
mask: Array<string>;
source: Array<string>;