Merge branch 'feat/00-prompt-tokens'
This commit is contained in:
commit
e653560f03
|
@ -17,7 +17,7 @@ with a CPU fallback capable of running on laptop-class machines.
|
|||
Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more
|
||||
details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
|
||||
|
||||
![txt2img with detailed knollingcase renders of a soldier in a cloudy alien jungle](./docs/readme-preview.png)
|
||||
![preview of txt2img tab using SDXL to generate ghostly astronauts eating weird hamburgers on an abandoned space station](./docs/readme-sdxl.png)
|
||||
|
||||
## Features
|
||||
|
||||
|
|
|
@ -115,7 +115,12 @@ def make_tile_mask(
|
|||
|
||||
# build gradients
|
||||
edge_t, edge_l, edge_b, edge_r = edges
|
||||
grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [int(not edge_t), 1, 1, int(not edge_b)]
|
||||
grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [
|
||||
int(not edge_t),
|
||||
1,
|
||||
1,
|
||||
int(not edge_b),
|
||||
]
|
||||
logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
|
||||
|
||||
mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]
|
||||
|
|
|
@ -660,8 +660,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
||||
region_noise_pred, 2
|
||||
)
|
||||
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
|
||||
region_noise_pred_text - region_noise_pred_uncond
|
||||
region_noise_pred = (
|
||||
region_noise_pred_uncond
|
||||
+ guidance_scale
|
||||
* (region_noise_pred_text - region_noise_pred_uncond)
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
|
|
@ -502,8 +502,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
||||
region_noise_pred, 2
|
||||
)
|
||||
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
|
||||
region_noise_pred_text - region_noise_pred_uncond
|
||||
region_noise_pred = (
|
||||
region_noise_pred_uncond
|
||||
+ guidance_scale
|
||||
* (region_noise_pred_text - region_noise_pred_uncond)
|
||||
)
|
||||
if guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
|
|
|
@ -379,6 +379,9 @@ def encode_prompt(
|
|||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
TODO: does not work with SDXL, fix or turn into a pipeline patch
|
||||
"""
|
||||
return [
|
||||
pipe._encode_prompt(
|
||||
remove_tokens(prompt),
|
||||
|
@ -456,7 +459,9 @@ def slice_prompt(prompt: str, slice: int) -> str:
|
|||
return prompt
|
||||
|
||||
|
||||
Region = Tuple[int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str]
|
||||
Region = Tuple[
|
||||
int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str
|
||||
]
|
||||
|
||||
|
||||
def parse_region_group(group: Tuple[str, ...]) -> Region:
|
||||
|
@ -475,12 +480,15 @@ def parse_region_group(group: Tuple[str, ...]) -> Region:
|
|||
int(bottom),
|
||||
int(right),
|
||||
float(weight),
|
||||
(float(feather_radius), (
|
||||
(
|
||||
float(feather_radius),
|
||||
(
|
||||
"T" in feather_edges,
|
||||
"L" in feather_edges,
|
||||
"B" in feather_edges,
|
||||
"R" in feather_edges,
|
||||
)),
|
||||
),
|
||||
),
|
||||
prompt,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,18 +1,21 @@
|
|||
from typing import Literal
|
||||
from typing import List, Literal
|
||||
|
||||
NetworkType = Literal["inversion", "lora"]
|
||||
|
||||
|
||||
class NetworkModel:
|
||||
name: str
|
||||
tokens: List[str]
|
||||
type: NetworkType
|
||||
|
||||
def __init__(self, name: str, type: NetworkType) -> None:
|
||||
def __init__(self, name: str, type: NetworkType, tokens=None) -> None:
|
||||
self.name = name
|
||||
self.tokens = tokens or []
|
||||
self.type = type
|
||||
|
||||
def tojson(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"tokens": self.tokens,
|
||||
"type": self.type,
|
||||
}
|
||||
|
|
|
@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
|
|||
# Loaded from extra_models
|
||||
extra_hashes: Dict[str, str] = {}
|
||||
extra_strings: Dict[str, Any] = {}
|
||||
extra_tokens: Dict[str, List[str]] = {}
|
||||
|
||||
|
||||
def get_config_params():
|
||||
|
@ -160,6 +161,7 @@ def load_extras(server: ServerContext):
|
|||
"""
|
||||
global extra_hashes
|
||||
global extra_strings
|
||||
global extra_tokens
|
||||
|
||||
labels = {}
|
||||
strings = {}
|
||||
|
@ -210,6 +212,14 @@ def load_extras(server: ServerContext):
|
|||
else:
|
||||
labels[model_name] = model["label"]
|
||||
|
||||
if "tokens" in model:
|
||||
logger.debug(
|
||||
"collecting tokens for model %s from %s",
|
||||
model_name,
|
||||
file,
|
||||
)
|
||||
extra_tokens[model_name] = model["tokens"]
|
||||
|
||||
if "inversions" in model:
|
||||
for inversion in model["inversions"]:
|
||||
if "label" in inversion:
|
||||
|
@ -353,7 +363,10 @@ def load_models(server: ServerContext) -> None:
|
|||
)
|
||||
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
||||
network_models.extend(
|
||||
[NetworkModel(model, "inversion") for model in inversion_models]
|
||||
[
|
||||
NetworkModel(model, "inversion", tokens=extra_tokens.get(model, []))
|
||||
for model in inversion_models
|
||||
]
|
||||
)
|
||||
|
||||
lora_models = list_model_globs(
|
||||
|
@ -364,7 +377,12 @@ def load_models(server: ServerContext) -> None:
|
|||
base_path=path.join(server.model_path, "lora"),
|
||||
)
|
||||
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
||||
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
|
||||
network_models.extend(
|
||||
[
|
||||
NetworkModel(model, "lora", tokens=extra_tokens.get(model, []))
|
||||
for model in lora_models
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def load_params(server: ServerContext) -> None:
|
||||
|
|
|
@ -10,34 +10,53 @@ $defs:
|
|||
- type: number
|
||||
- type: string
|
||||
|
||||
lora_network:
|
||||
tensor_format:
|
||||
type: string
|
||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
||||
|
||||
embedding_network:
|
||||
type: object
|
||||
required: [name, source]
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
format:
|
||||
$ref: "#/$defs/tensor_format"
|
||||
label:
|
||||
type: string
|
||||
weight:
|
||||
type: number
|
||||
|
||||
textual_inversion_network:
|
||||
type: object
|
||||
required: [name, source]
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
format:
|
||||
model:
|
||||
type: string
|
||||
enum: [concept, embeddings]
|
||||
label:
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
token:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
const: inversion # TODO: add embedding
|
||||
weight:
|
||||
type: number
|
||||
|
||||
lora_network:
|
||||
type: object
|
||||
required: [name, source, type]
|
||||
properties:
|
||||
label:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
enum: [cloneofsimo, sd-scripts]
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
tokens:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
const: lora
|
||||
weight:
|
||||
type: number
|
||||
|
||||
|
@ -46,8 +65,7 @@ $defs:
|
|||
required: [name, source]
|
||||
properties:
|
||||
format:
|
||||
type: string
|
||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
||||
$ref: "#/$defs/tensor_format"
|
||||
half:
|
||||
type: boolean
|
||||
label:
|
||||
|
@ -85,7 +103,7 @@ $defs:
|
|||
inversions:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/textual_inversion_network"
|
||||
$ref: "#/$defs/embedding_network"
|
||||
loras:
|
||||
type: array
|
||||
items:
|
||||
|
@ -142,31 +160,6 @@ $defs:
|
|||
source:
|
||||
type: string
|
||||
|
||||
source_network:
|
||||
type: object
|
||||
required: [name, source, type]
|
||||
properties:
|
||||
format:
|
||||
type: string
|
||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
||||
model:
|
||||
type: string
|
||||
enum: [
|
||||
# inversion
|
||||
concept,
|
||||
embeddings,
|
||||
# lora
|
||||
cloneofsimo,
|
||||
sd-scripts
|
||||
]
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
enum: [inversion, lora]
|
||||
|
||||
translation:
|
||||
type: object
|
||||
additionalProperties: False
|
||||
|
@ -194,7 +187,9 @@ properties:
|
|||
networks:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/source_network"
|
||||
oneOf:
|
||||
- $ref: "#/$defs/lora_network"
|
||||
- $ref: "#/$defs/embedding_network"
|
||||
sources:
|
||||
type: array
|
||||
items:
|
||||
|
|
Binary file not shown.
|
@ -1,5 +1,5 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { TextField } from '@mui/material';
|
||||
import { Maybe, doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
|
||||
import { Chip, TextField } from '@mui/material';
|
||||
import { Stack } from '@mui/system';
|
||||
import { useQuery } from '@tanstack/react-query';
|
||||
import * as React from 'react';
|
||||
|
@ -10,6 +10,7 @@ import { shallow } from 'zustand/shallow';
|
|||
import { STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { QueryMenu } from '../input/QueryMenu.js';
|
||||
import { ModelResponse } from '../../types/api.js';
|
||||
|
||||
const { useContext } = React;
|
||||
|
||||
|
@ -48,26 +49,27 @@ export function PromptInput(props: PromptInputProps) {
|
|||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
const tokens = splitPrompt(prompt);
|
||||
const groups = Math.ceil(tokens.length / PROMPT_GROUP);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const helper = t('input.prompt.tokens', {
|
||||
groups,
|
||||
tokens: tokens.length,
|
||||
});
|
||||
|
||||
function addToken(type: string, name: string, weight = 1.0) {
|
||||
function addNetwork(type: string, name: string, weight = 1.0) {
|
||||
onChange({
|
||||
prompt: `<${type}:${name}:1.0> ${prompt}`,
|
||||
negativePrompt,
|
||||
});
|
||||
}
|
||||
|
||||
function addToken(name: string) {
|
||||
onChange({
|
||||
prompt: `${prompt}, ${name}`,
|
||||
});
|
||||
}
|
||||
|
||||
const networks = extractNetworks(prompt);
|
||||
const tokens = getNetworkTokens(models.data, networks);
|
||||
|
||||
return <Stack spacing={2}>
|
||||
<TextField
|
||||
label={t('parameter.prompt')}
|
||||
helperText={helper}
|
||||
variant='outlined'
|
||||
value={prompt}
|
||||
onChange={(event) => {
|
||||
|
@ -77,6 +79,13 @@ export function PromptInput(props: PromptInputProps) {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
<Stack direction='row' spacing={2}>
|
||||
{tokens.map(([token, _weight]) => <Chip
|
||||
color={prompt.includes(token) ? 'primary' : 'default'}
|
||||
label={token}
|
||||
onClick={() => addToken(token)}
|
||||
/>)}
|
||||
</Stack>
|
||||
<TextField
|
||||
label={t('parameter.negativePrompt')}
|
||||
variant='outlined'
|
||||
|
@ -98,7 +107,7 @@ export function PromptInput(props: PromptInputProps) {
|
|||
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
||||
}}
|
||||
onSelect={(name) => {
|
||||
addToken('inversion', name);
|
||||
addNetwork('inversion', name);
|
||||
}}
|
||||
/>
|
||||
<QueryMenu
|
||||
|
@ -110,9 +119,69 @@ export function PromptInput(props: PromptInputProps) {
|
|||
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
||||
}}
|
||||
onSelect={(name) => {
|
||||
addToken('lora', name);
|
||||
addNetwork('lora', name);
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
</Stack>;
|
||||
}
|
||||
|
||||
export const ANY_TOKEN = /<([^>]+)>/g;
|
||||
|
||||
export type TokenList = Array<[string, number]>;
|
||||
|
||||
export interface PromptNetworks {
|
||||
inversion: TokenList;
|
||||
lora: TokenList;
|
||||
}
|
||||
|
||||
export function extractNetworks(prompt: string): PromptNetworks {
|
||||
const inversion: TokenList = [];
|
||||
const lora: TokenList = [];
|
||||
|
||||
for (const token of prompt.matchAll(ANY_TOKEN)) {
|
||||
const [_whole, match] = Array.from(token);
|
||||
const [type, name, weight, ..._rest] = match.split(':');
|
||||
|
||||
switch (type) {
|
||||
case 'inversion':
|
||||
inversion.push([name, parseFloat(weight)]);
|
||||
break;
|
||||
case 'lora':
|
||||
lora.push([name, parseFloat(weight)]);
|
||||
break;
|
||||
default:
|
||||
// ignore others
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
inversion,
|
||||
lora,
|
||||
};
|
||||
}
|
||||
|
||||
// eslint-disable-next-line sonarjs/cognitive-complexity
|
||||
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
|
||||
const tokens: TokenList = [];
|
||||
|
||||
if (doesExist(models)) {
|
||||
for (const [name, weight] of networks.inversion) {
|
||||
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
|
||||
if (doesExist(model) && model.type === 'inversion') {
|
||||
tokens.push([model.token, weight]);
|
||||
}
|
||||
}
|
||||
|
||||
for (const [name, weight] of networks.lora) {
|
||||
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
|
||||
if (doesExist(model) && model.type === 'lora') {
|
||||
for (const token of mustDefault(model.tokens, [])) {
|
||||
tokens.push([token, weight]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
|
|
@ -39,13 +39,28 @@ export interface ReadyResponse {
|
|||
ready: boolean;
|
||||
}
|
||||
|
||||
export interface NetworkModel {
|
||||
export interface ControlNetwork {
|
||||
name: string;
|
||||
type: 'control' | 'inversion' | 'lora';
|
||||
// TODO: add token
|
||||
// TODO: add layer/token count
|
||||
type: 'control';
|
||||
}
|
||||
|
||||
export interface EmbeddingNetwork {
|
||||
label: string;
|
||||
name: string;
|
||||
token: string;
|
||||
type: 'inversion';
|
||||
// TODO: add layer count
|
||||
}
|
||||
|
||||
export interface LoraNetwork {
|
||||
name: string;
|
||||
label: string;
|
||||
tokens: Array<string>;
|
||||
type: 'lora';
|
||||
}
|
||||
|
||||
export type NetworkModel = EmbeddingNetwork | LoraNetwork | ControlNetwork;
|
||||
|
||||
export interface FilterResponse {
|
||||
mask: Array<string>;
|
||||
source: Array<string>;
|
||||
|
|
Loading…
Reference in New Issue