feat: show tokens for networks in prompt
This commit is contained in:
parent
3ffbc00390
commit
44e483322e
|
@ -379,6 +379,9 @@ def encode_prompt(
|
||||||
num_images_per_prompt: int = 1,
|
num_images_per_prompt: int = 1,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
) -> List[np.ndarray]:
|
) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
TODO: does not work with SDXL, fix or turn into a pipeline patch
|
||||||
|
"""
|
||||||
return [
|
return [
|
||||||
pipe._encode_prompt(
|
pipe._encode_prompt(
|
||||||
remove_tokens(prompt),
|
remove_tokens(prompt),
|
||||||
|
|
|
@ -10,34 +10,53 @@ $defs:
|
||||||
- type: number
|
- type: number
|
||||||
- type: string
|
- type: string
|
||||||
|
|
||||||
lora_network:
|
tensor_format:
|
||||||
|
type: string
|
||||||
|
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
||||||
|
|
||||||
|
embedding_network:
|
||||||
type: object
|
type: object
|
||||||
required: [name, source]
|
required: [name, source]
|
||||||
properties:
|
properties:
|
||||||
name:
|
format:
|
||||||
type: string
|
$ref: "#/defs/tensor_format"
|
||||||
source:
|
|
||||||
type: string
|
|
||||||
label:
|
label:
|
||||||
type: string
|
type: string
|
||||||
weight:
|
model:
|
||||||
type: number
|
|
||||||
|
|
||||||
textual_inversion_network:
|
|
||||||
type: object
|
|
||||||
required: [name, source]
|
|
||||||
properties:
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
source:
|
|
||||||
type: string
|
|
||||||
format:
|
|
||||||
type: string
|
type: string
|
||||||
enum: [concept, embeddings]
|
enum: [concept, embeddings]
|
||||||
label:
|
name:
|
||||||
|
type: string
|
||||||
|
source:
|
||||||
type: string
|
type: string
|
||||||
token:
|
token:
|
||||||
type: string
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: inversion # TODO: add embedding
|
||||||
|
weight:
|
||||||
|
type: number
|
||||||
|
|
||||||
|
lora_network:
|
||||||
|
type: object
|
||||||
|
required: [name, source, type]
|
||||||
|
properties:
|
||||||
|
label:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
enum: [cloneofsimo, sd-scripts]
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
source:
|
||||||
|
type: string
|
||||||
|
tokens:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: lora
|
||||||
weight:
|
weight:
|
||||||
type: number
|
type: number
|
||||||
|
|
||||||
|
@ -46,8 +65,7 @@ $defs:
|
||||||
required: [name, source]
|
required: [name, source]
|
||||||
properties:
|
properties:
|
||||||
format:
|
format:
|
||||||
type: string
|
$ref: "#/defs/tensor_format"
|
||||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
|
||||||
half:
|
half:
|
||||||
type: boolean
|
type: boolean
|
||||||
label:
|
label:
|
||||||
|
@ -85,7 +103,7 @@ $defs:
|
||||||
inversions:
|
inversions:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: "#/$defs/textual_inversion_network"
|
$ref: "#/$defs/embedding_network"
|
||||||
loras:
|
loras:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
|
@ -142,31 +160,6 @@ $defs:
|
||||||
source:
|
source:
|
||||||
type: string
|
type: string
|
||||||
|
|
||||||
source_network:
|
|
||||||
type: object
|
|
||||||
required: [name, source, type]
|
|
||||||
properties:
|
|
||||||
format:
|
|
||||||
type: string
|
|
||||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
|
||||||
model:
|
|
||||||
type: string
|
|
||||||
enum: [
|
|
||||||
# inversion
|
|
||||||
concept,
|
|
||||||
embeddings,
|
|
||||||
# lora
|
|
||||||
cloneofsimo,
|
|
||||||
sd-scripts
|
|
||||||
]
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
source:
|
|
||||||
type: string
|
|
||||||
type:
|
|
||||||
type: string
|
|
||||||
enum: [inversion, lora]
|
|
||||||
|
|
||||||
translation:
|
translation:
|
||||||
type: object
|
type: object
|
||||||
additionalProperties: False
|
additionalProperties: False
|
||||||
|
@ -194,7 +187,9 @@ properties:
|
||||||
networks:
|
networks:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: "#/$defs/source_network"
|
oneOf:
|
||||||
|
- $ref: "#/$defs/lora_network"
|
||||||
|
- $ref: "#/$defs/embedding_network"
|
||||||
sources:
|
sources:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import { mustExist } from '@apextoaster/js-utils';
|
import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
|
||||||
import { TextField } from '@mui/material';
|
import { Chip, TextField } from '@mui/material';
|
||||||
import { Stack } from '@mui/system';
|
import { Stack } from '@mui/system';
|
||||||
import { useQuery } from '@tanstack/react-query';
|
import { useQuery } from '@tanstack/react-query';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
|
@ -10,6 +10,7 @@ import { shallow } from 'zustand/shallow';
|
||||||
import { STALE_TIME } from '../../config.js';
|
import { STALE_TIME } from '../../config.js';
|
||||||
import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
||||||
import { QueryMenu } from '../input/QueryMenu.js';
|
import { QueryMenu } from '../input/QueryMenu.js';
|
||||||
|
import { ModelResponse } from '../../types/api.js';
|
||||||
|
|
||||||
const { useContext } = React;
|
const { useContext } = React;
|
||||||
|
|
||||||
|
@ -48,26 +49,27 @@ export function PromptInput(props: PromptInputProps) {
|
||||||
staleTime: STALE_TIME,
|
staleTime: STALE_TIME,
|
||||||
});
|
});
|
||||||
|
|
||||||
const tokens = splitPrompt(prompt);
|
|
||||||
const groups = Math.ceil(tokens.length / PROMPT_GROUP);
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const helper = t('input.prompt.tokens', {
|
|
||||||
groups,
|
|
||||||
tokens: tokens.length,
|
|
||||||
});
|
|
||||||
|
|
||||||
function addToken(type: string, name: string, weight = 1.0) {
|
function addNetwork(type: string, name: string, weight = 1.0) {
|
||||||
onChange({
|
onChange({
|
||||||
prompt: `<${type}:${name}:1.0> ${prompt}`,
|
prompt: `<${type}:${name}:1.0> ${prompt}`,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function addToken(name: string) {
|
||||||
|
onChange({
|
||||||
|
prompt: `${prompt}, ${name}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const networks = extractNetworks(prompt);
|
||||||
|
const tokens = getNetworkTokens(models.data, networks);
|
||||||
|
|
||||||
return <Stack spacing={2}>
|
return <Stack spacing={2}>
|
||||||
<TextField
|
<TextField
|
||||||
label={t('parameter.prompt')}
|
label={t('parameter.prompt')}
|
||||||
helperText={helper}
|
|
||||||
variant='outlined'
|
variant='outlined'
|
||||||
value={prompt}
|
value={prompt}
|
||||||
onChange={(event) => {
|
onChange={(event) => {
|
||||||
|
@ -77,6 +79,7 @@ export function PromptInput(props: PromptInputProps) {
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
{tokens.map(([token, _weight]) => <Chip label={token} onClick={() => addToken(token)} />)}
|
||||||
<TextField
|
<TextField
|
||||||
label={t('parameter.negativePrompt')}
|
label={t('parameter.negativePrompt')}
|
||||||
variant='outlined'
|
variant='outlined'
|
||||||
|
@ -98,7 +101,7 @@ export function PromptInput(props: PromptInputProps) {
|
||||||
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
||||||
}}
|
}}
|
||||||
onSelect={(name) => {
|
onSelect={(name) => {
|
||||||
addToken('inversion', name);
|
addNetwork('inversion', name);
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<QueryMenu
|
<QueryMenu
|
||||||
|
@ -110,9 +113,69 @@ export function PromptInput(props: PromptInputProps) {
|
||||||
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
||||||
}}
|
}}
|
||||||
onSelect={(name) => {
|
onSelect={(name) => {
|
||||||
addToken('lora', name);
|
addNetwork('lora', name);
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</Stack>
|
</Stack>
|
||||||
</Stack>;
|
</Stack>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const ANY_TOKEN = /<([^>])+>/g;
|
||||||
|
|
||||||
|
export type TokenList = Array<[string, number]>;
|
||||||
|
|
||||||
|
export interface PromptNetworks {
|
||||||
|
inversion: TokenList;
|
||||||
|
lora: TokenList;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function extractNetworks(prompt: string): PromptNetworks {
|
||||||
|
const inversion: TokenList = [];
|
||||||
|
const lora: TokenList = [];
|
||||||
|
|
||||||
|
for (const token of prompt.matchAll(ANY_TOKEN)) {
|
||||||
|
const [_whole, match] = Array.from(token);
|
||||||
|
const [type, name, weight, ..._rest] = match.split(':');
|
||||||
|
|
||||||
|
switch (type) {
|
||||||
|
case 'inversion':
|
||||||
|
inversion.push([name, parseFloat(weight)]);
|
||||||
|
break;
|
||||||
|
case 'lora':
|
||||||
|
lora.push([name, parseFloat(weight)]);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
// ignore others
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
inversion,
|
||||||
|
lora,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// eslint-disable-next-line sonarjs/cognitive-complexity
|
||||||
|
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
|
||||||
|
const tokens: TokenList = [];
|
||||||
|
|
||||||
|
if (doesExist(models)) {
|
||||||
|
for (const [name, weight] of networks.inversion) {
|
||||||
|
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
|
||||||
|
if (doesExist(model) && model.type === 'inversion') {
|
||||||
|
tokens.push([model.token, weight]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const [name, weight] of networks.inversion) {
|
||||||
|
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
|
||||||
|
if (doesExist(model) && model.type === 'lora') {
|
||||||
|
for (const token of model.tokens) {
|
||||||
|
tokens.push([token, weight]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
|
@ -39,13 +39,28 @@ export interface ReadyResponse {
|
||||||
ready: boolean;
|
ready: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface NetworkModel {
|
export interface ControlNetwork {
|
||||||
name: string;
|
name: string;
|
||||||
type: 'control' | 'inversion' | 'lora';
|
type: 'control';
|
||||||
// TODO: add token
|
|
||||||
// TODO: add layer/token count
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface EmbeddingNetwork {
|
||||||
|
label: string;
|
||||||
|
name: string;
|
||||||
|
token: string;
|
||||||
|
type: 'inversion';
|
||||||
|
// TODO: add layer count
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface LoraNetwork {
|
||||||
|
name: string;
|
||||||
|
label: string;
|
||||||
|
tokens: Array<string>;
|
||||||
|
type: 'lora';
|
||||||
|
}
|
||||||
|
|
||||||
|
export type NetworkModel = EmbeddingNetwork | LoraNetwork | ControlNetwork;
|
||||||
|
|
||||||
export interface FilterResponse {
|
export interface FilterResponse {
|
||||||
mask: Array<string>;
|
mask: Array<string>;
|
||||||
source: Array<string>;
|
source: Array<string>;
|
||||||
|
|
Loading…
Reference in New Issue