feat: show tokens for networks in prompt
This commit is contained in:
parent
3ffbc00390
commit
44e483322e
|
@ -379,6 +379,9 @@ def encode_prompt(
|
|||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
TODO: does not work with SDXL, fix or turn into a pipeline patch
|
||||
"""
|
||||
return [
|
||||
pipe._encode_prompt(
|
||||
remove_tokens(prompt),
|
||||
|
|
|
@ -10,34 +10,53 @@ $defs:
|
|||
- type: number
|
||||
- type: string
|
||||
|
||||
lora_network:
|
||||
tensor_format:
|
||||
type: string
|
||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
||||
|
||||
embedding_network:
|
||||
type: object
|
||||
required: [name, source]
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
format:
|
||||
$ref: "#/defs/tensor_format"
|
||||
label:
|
||||
type: string
|
||||
weight:
|
||||
type: number
|
||||
|
||||
textual_inversion_network:
|
||||
type: object
|
||||
required: [name, source]
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
format:
|
||||
model:
|
||||
type: string
|
||||
enum: [concept, embeddings]
|
||||
label:
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
token:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
const: inversion # TODO: add embedding
|
||||
weight:
|
||||
type: number
|
||||
|
||||
lora_network:
|
||||
type: object
|
||||
required: [name, source, type]
|
||||
properties:
|
||||
label:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
enum: [cloneofsimo, sd-scripts]
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
tokens:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
const: lora
|
||||
weight:
|
||||
type: number
|
||||
|
||||
|
@ -46,8 +65,7 @@ $defs:
|
|||
required: [name, source]
|
||||
properties:
|
||||
format:
|
||||
type: string
|
||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
||||
$ref: "#/defs/tensor_format"
|
||||
half:
|
||||
type: boolean
|
||||
label:
|
||||
|
@ -85,7 +103,7 @@ $defs:
|
|||
inversions:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/textual_inversion_network"
|
||||
$ref: "#/$defs/embedding_network"
|
||||
loras:
|
||||
type: array
|
||||
items:
|
||||
|
@ -142,31 +160,6 @@ $defs:
|
|||
source:
|
||||
type: string
|
||||
|
||||
source_network:
|
||||
type: object
|
||||
required: [name, source, type]
|
||||
properties:
|
||||
format:
|
||||
type: string
|
||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
||||
model:
|
||||
type: string
|
||||
enum: [
|
||||
# inversion
|
||||
concept,
|
||||
embeddings,
|
||||
# lora
|
||||
cloneofsimo,
|
||||
sd-scripts
|
||||
]
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
enum: [inversion, lora]
|
||||
|
||||
translation:
|
||||
type: object
|
||||
additionalProperties: False
|
||||
|
@ -194,7 +187,9 @@ properties:
|
|||
networks:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/source_network"
|
||||
oneOf:
|
||||
- $ref: "#/$defs/lora_network"
|
||||
- $ref: "#/$defs/embedding_network"
|
||||
sources:
|
||||
type: array
|
||||
items:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { TextField } from '@mui/material';
|
||||
import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
|
||||
import { Chip, TextField } from '@mui/material';
|
||||
import { Stack } from '@mui/system';
|
||||
import { useQuery } from '@tanstack/react-query';
|
||||
import * as React from 'react';
|
||||
|
@ -10,6 +10,7 @@ import { shallow } from 'zustand/shallow';
|
|||
import { STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { QueryMenu } from '../input/QueryMenu.js';
|
||||
import { ModelResponse } from '../../types/api.js';
|
||||
|
||||
const { useContext } = React;
|
||||
|
||||
|
@ -48,26 +49,27 @@ export function PromptInput(props: PromptInputProps) {
|
|||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
const tokens = splitPrompt(prompt);
|
||||
const groups = Math.ceil(tokens.length / PROMPT_GROUP);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const helper = t('input.prompt.tokens', {
|
||||
groups,
|
||||
tokens: tokens.length,
|
||||
});
|
||||
|
||||
function addToken(type: string, name: string, weight = 1.0) {
|
||||
function addNetwork(type: string, name: string, weight = 1.0) {
|
||||
onChange({
|
||||
prompt: `<${type}:${name}:1.0> ${prompt}`,
|
||||
negativePrompt,
|
||||
});
|
||||
}
|
||||
|
||||
function addToken(name: string) {
|
||||
onChange({
|
||||
prompt: `${prompt}, ${name}`,
|
||||
});
|
||||
}
|
||||
|
||||
const networks = extractNetworks(prompt);
|
||||
const tokens = getNetworkTokens(models.data, networks);
|
||||
|
||||
return <Stack spacing={2}>
|
||||
<TextField
|
||||
label={t('parameter.prompt')}
|
||||
helperText={helper}
|
||||
variant='outlined'
|
||||
value={prompt}
|
||||
onChange={(event) => {
|
||||
|
@ -77,6 +79,7 @@ export function PromptInput(props: PromptInputProps) {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
{tokens.map(([token, _weight]) => <Chip label={token} onClick={() => addToken(token)} />)}
|
||||
<TextField
|
||||
label={t('parameter.negativePrompt')}
|
||||
variant='outlined'
|
||||
|
@ -98,7 +101,7 @@ export function PromptInput(props: PromptInputProps) {
|
|||
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
||||
}}
|
||||
onSelect={(name) => {
|
||||
addToken('inversion', name);
|
||||
addNetwork('inversion', name);
|
||||
}}
|
||||
/>
|
||||
<QueryMenu
|
||||
|
@ -110,9 +113,69 @@ export function PromptInput(props: PromptInputProps) {
|
|||
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
||||
}}
|
||||
onSelect={(name) => {
|
||||
addToken('lora', name);
|
||||
addNetwork('lora', name);
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
</Stack>;
|
||||
}
|
||||
|
||||
export const ANY_TOKEN = /<([^>])+>/g;
|
||||
|
||||
export type TokenList = Array<[string, number]>;
|
||||
|
||||
export interface PromptNetworks {
|
||||
inversion: TokenList;
|
||||
lora: TokenList;
|
||||
}
|
||||
|
||||
export function extractNetworks(prompt: string): PromptNetworks {
|
||||
const inversion: TokenList = [];
|
||||
const lora: TokenList = [];
|
||||
|
||||
for (const token of prompt.matchAll(ANY_TOKEN)) {
|
||||
const [_whole, match] = Array.from(token);
|
||||
const [type, name, weight, ..._rest] = match.split(':');
|
||||
|
||||
switch (type) {
|
||||
case 'inversion':
|
||||
inversion.push([name, parseFloat(weight)]);
|
||||
break;
|
||||
case 'lora':
|
||||
lora.push([name, parseFloat(weight)]);
|
||||
break;
|
||||
default:
|
||||
// ignore others
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
inversion,
|
||||
lora,
|
||||
};
|
||||
}
|
||||
|
||||
// eslint-disable-next-line sonarjs/cognitive-complexity
|
||||
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
|
||||
const tokens: TokenList = [];
|
||||
|
||||
if (doesExist(models)) {
|
||||
for (const [name, weight] of networks.inversion) {
|
||||
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
|
||||
if (doesExist(model) && model.type === 'inversion') {
|
||||
tokens.push([model.token, weight]);
|
||||
}
|
||||
}
|
||||
|
||||
for (const [name, weight] of networks.inversion) {
|
||||
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
|
||||
if (doesExist(model) && model.type === 'lora') {
|
||||
for (const token of model.tokens) {
|
||||
tokens.push([token, weight]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
|
|
@ -39,13 +39,28 @@ export interface ReadyResponse {
|
|||
ready: boolean;
|
||||
}
|
||||
|
||||
export interface NetworkModel {
|
||||
export interface ControlNetwork {
|
||||
name: string;
|
||||
type: 'control' | 'inversion' | 'lora';
|
||||
// TODO: add token
|
||||
// TODO: add layer/token count
|
||||
type: 'control';
|
||||
}
|
||||
|
||||
export interface EmbeddingNetwork {
|
||||
label: string;
|
||||
name: string;
|
||||
token: string;
|
||||
type: 'inversion';
|
||||
// TODO: add layer count
|
||||
}
|
||||
|
||||
export interface LoraNetwork {
|
||||
name: string;
|
||||
label: string;
|
||||
tokens: Array<string>;
|
||||
type: 'lora';
|
||||
}
|
||||
|
||||
export type NetworkModel = EmbeddingNetwork | LoraNetwork | ControlNetwork;
|
||||
|
||||
export interface FilterResponse {
|
||||
mask: Array<string>;
|
||||
source: Array<string>;
|
||||
|
|
Loading…
Reference in New Issue