1
0
Fork 0

Merge branch 'feat/00-prompt-tokens'

This commit is contained in:
Sean Sube 2023-11-12 16:40:53 -06:00
commit e653560f03
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
11 changed files with 201 additions and 81 deletions

View File

@ -17,7 +17,7 @@ with a CPU fallback capable of running on laptop-class machines.
Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more
details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md). details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
![txt2img with detailed knollingcase renders of a soldier in a cloudy alien jungle](./docs/readme-preview.png) ![preview of txt2img tab using SDXL to generate ghostly astronauts eating weird hamburgers on an abandoned space station](./docs/readme-sdxl.png)
## Features ## Features

View File

@ -115,7 +115,12 @@ def make_tile_mask(
# build gradients # build gradients
edge_t, edge_l, edge_b, edge_r = edges edge_t, edge_l, edge_b, edge_r = edges
grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [int(not edge_t), 1, 1, int(not edge_b)] grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [
int(not edge_t),
1,
1,
int(not edge_b),
]
logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y) logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)] mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]

View File

@ -660,8 +660,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
region_noise_pred_uncond, region_noise_pred_text = np.split( region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2 region_noise_pred, 2
) )
region_noise_pred = region_noise_pred_uncond + guidance_scale * ( region_noise_pred = (
region_noise_pred_text - region_noise_pred_uncond region_noise_pred_uncond
+ guidance_scale
* (region_noise_pred_text - region_noise_pred_uncond)
) )
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1

View File

@ -502,8 +502,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
region_noise_pred_uncond, region_noise_pred_text = np.split( region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2 region_noise_pred, 2
) )
region_noise_pred = region_noise_pred_uncond + guidance_scale * ( region_noise_pred = (
region_noise_pred_text - region_noise_pred_uncond region_noise_pred_uncond
+ guidance_scale
* (region_noise_pred_text - region_noise_pred_uncond)
) )
if guidance_rescale > 0.0: if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf

View File

@ -379,6 +379,9 @@ def encode_prompt(
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
) -> List[np.ndarray]: ) -> List[np.ndarray]:
"""
TODO: does not work with SDXL, fix or turn into a pipeline patch
"""
return [ return [
pipe._encode_prompt( pipe._encode_prompt(
remove_tokens(prompt), remove_tokens(prompt),
@ -456,7 +459,9 @@ def slice_prompt(prompt: str, slice: int) -> str:
return prompt return prompt
Region = Tuple[int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str] Region = Tuple[
int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str
]
def parse_region_group(group: Tuple[str, ...]) -> Region: def parse_region_group(group: Tuple[str, ...]) -> Region:
@ -475,12 +480,15 @@ def parse_region_group(group: Tuple[str, ...]) -> Region:
int(bottom), int(bottom),
int(right), int(right),
float(weight), float(weight),
(float(feather_radius), ( (
"T" in feather_edges, float(feather_radius),
"L" in feather_edges, (
"B" in feather_edges, "T" in feather_edges,
"R" in feather_edges, "L" in feather_edges,
)), "B" in feather_edges,
"R" in feather_edges,
),
),
prompt, prompt,
) )

View File

@ -1,18 +1,21 @@
from typing import Literal from typing import List, Literal
NetworkType = Literal["inversion", "lora"] NetworkType = Literal["inversion", "lora"]
class NetworkModel: class NetworkModel:
name: str name: str
tokens: List[str]
type: NetworkType type: NetworkType
def __init__(self, name: str, type: NetworkType) -> None: def __init__(self, name: str, type: NetworkType, tokens=None) -> None:
self.name = name self.name = name
self.tokens = tokens or []
self.type = type self.type = type
def tojson(self): def tojson(self):
return { return {
"name": self.name, "name": self.name,
"tokens": self.tokens,
"type": self.type, "type": self.type,
} }

View File

@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
# Loaded from extra_models # Loaded from extra_models
extra_hashes: Dict[str, str] = {} extra_hashes: Dict[str, str] = {}
extra_strings: Dict[str, Any] = {} extra_strings: Dict[str, Any] = {}
extra_tokens: Dict[str, List[str]] = {}
def get_config_params(): def get_config_params():
@ -160,6 +161,7 @@ def load_extras(server: ServerContext):
""" """
global extra_hashes global extra_hashes
global extra_strings global extra_strings
global extra_tokens
labels = {} labels = {}
strings = {} strings = {}
@ -210,6 +212,14 @@ def load_extras(server: ServerContext):
else: else:
labels[model_name] = model["label"] labels[model_name] = model["label"]
if "tokens" in model:
logger.debug(
"collecting tokens for model %s from %s",
model_name,
file,
)
extra_tokens[model_name] = model["tokens"]
if "inversions" in model: if "inversions" in model:
for inversion in model["inversions"]: for inversion in model["inversions"]:
if "label" in inversion: if "label" in inversion:
@ -353,7 +363,10 @@ def load_models(server: ServerContext) -> None:
) )
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models) logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
network_models.extend( network_models.extend(
[NetworkModel(model, "inversion") for model in inversion_models] [
NetworkModel(model, "inversion", tokens=extra_tokens.get(model, []))
for model in inversion_models
]
) )
lora_models = list_model_globs( lora_models = list_model_globs(
@ -364,7 +377,12 @@ def load_models(server: ServerContext) -> None:
base_path=path.join(server.model_path, "lora"), base_path=path.join(server.model_path, "lora"),
) )
logger.debug("loaded LoRA models from disk: %s", lora_models) logger.debug("loaded LoRA models from disk: %s", lora_models)
network_models.extend([NetworkModel(model, "lora") for model in lora_models]) network_models.extend(
[
NetworkModel(model, "lora", tokens=extra_tokens.get(model, []))
for model in lora_models
]
)
def load_params(server: ServerContext) -> None: def load_params(server: ServerContext) -> None:

View File

@ -10,34 +10,53 @@ $defs:
- type: number - type: number
- type: string - type: string
lora_network: tensor_format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
embedding_network:
type: object type: object
required: [name, source] required: [name, source]
properties: properties:
name: format:
type: string $ref: "#/$defs/tensor_format"
source:
type: string
label: label:
type: string type: string
weight: model:
type: number
textual_inversion_network:
type: object
required: [name, source]
properties:
name:
type: string
source:
type: string
format:
type: string type: string
enum: [concept, embeddings] enum: [concept, embeddings]
label: name:
type: string
source:
type: string type: string
token: token:
type: string type: string
type:
type: string
const: inversion # TODO: add embedding
weight:
type: number
lora_network:
type: object
required: [name, source, type]
properties:
label:
type: string
model:
type: string
enum: [cloneofsimo, sd-scripts]
name:
type: string
source:
type: string
tokens:
type: array
items:
type: string
type:
type: string
const: lora
weight: weight:
type: number type: number
@ -46,8 +65,7 @@ $defs:
required: [name, source] required: [name, source]
properties: properties:
format: format:
type: string $ref: "#/$defs/tensor_format"
enum: [bin, ckpt, onnx, pt, pth, safetensors]
half: half:
type: boolean type: boolean
label: label:
@ -85,7 +103,7 @@ $defs:
inversions: inversions:
type: array type: array
items: items:
$ref: "#/$defs/textual_inversion_network" $ref: "#/$defs/embedding_network"
loras: loras:
type: array type: array
items: items:
@ -142,31 +160,6 @@ $defs:
source: source:
type: string type: string
source_network:
type: object
required: [name, source, type]
properties:
format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
model:
type: string
enum: [
# inversion
concept,
embeddings,
# lora
cloneofsimo,
sd-scripts
]
name:
type: string
source:
type: string
type:
type: string
enum: [inversion, lora]
translation: translation:
type: object type: object
additionalProperties: False additionalProperties: False
@ -194,7 +187,9 @@ properties:
networks: networks:
type: array type: array
items: items:
$ref: "#/$defs/source_network" oneOf:
- $ref: "#/$defs/lora_network"
- $ref: "#/$defs/embedding_network"
sources: sources:
type: array type: array
items: items:

BIN
docs/readme-sdxl.png (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -1,5 +1,5 @@
import { mustExist } from '@apextoaster/js-utils'; import { Maybe, doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { TextField } from '@mui/material'; import { Chip, TextField } from '@mui/material';
import { Stack } from '@mui/system'; import { Stack } from '@mui/system';
import { useQuery } from '@tanstack/react-query'; import { useQuery } from '@tanstack/react-query';
import * as React from 'react'; import * as React from 'react';
@ -10,6 +10,7 @@ import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js'; import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js'; import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { QueryMenu } from '../input/QueryMenu.js'; import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js';
const { useContext } = React; const { useContext } = React;
@ -48,26 +49,27 @@ export function PromptInput(props: PromptInputProps) {
staleTime: STALE_TIME, staleTime: STALE_TIME,
}); });
const tokens = splitPrompt(prompt);
const groups = Math.ceil(tokens.length / PROMPT_GROUP);
const { t } = useTranslation(); const { t } = useTranslation();
const helper = t('input.prompt.tokens', {
groups,
tokens: tokens.length,
});
function addToken(type: string, name: string, weight = 1.0) { function addNetwork(type: string, name: string, weight = 1.0) {
onChange({ onChange({
prompt: `<${type}:${name}:1.0> ${prompt}`, prompt: `<${type}:${name}:1.0> ${prompt}`,
negativePrompt, negativePrompt,
}); });
} }
function addToken(name: string) {
onChange({
prompt: `${prompt}, ${name}`,
});
}
const networks = extractNetworks(prompt);
const tokens = getNetworkTokens(models.data, networks);
return <Stack spacing={2}> return <Stack spacing={2}>
<TextField <TextField
label={t('parameter.prompt')} label={t('parameter.prompt')}
helperText={helper}
variant='outlined' variant='outlined'
value={prompt} value={prompt}
onChange={(event) => { onChange={(event) => {
@ -77,6 +79,13 @@ export function PromptInput(props: PromptInputProps) {
}); });
}} }}
/> />
<Stack direction='row' spacing={2}>
{tokens.map(([token, _weight]) => <Chip
color={prompt.includes(token) ? 'primary' : 'default'}
label={token}
onClick={() => addToken(token)}
/>)}
</Stack>
<TextField <TextField
label={t('parameter.negativePrompt')} label={t('parameter.negativePrompt')}
variant='outlined' variant='outlined'
@ -98,7 +107,7 @@ export function PromptInput(props: PromptInputProps) {
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name), selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}} }}
onSelect={(name) => { onSelect={(name) => {
addToken('inversion', name); addNetwork('inversion', name);
}} }}
/> />
<QueryMenu <QueryMenu
@ -110,9 +119,69 @@ export function PromptInput(props: PromptInputProps) {
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name), selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}} }}
onSelect={(name) => { onSelect={(name) => {
addToken('lora', name); addNetwork('lora', name);
}} }}
/> />
</Stack> </Stack>
</Stack>; </Stack>;
} }
export const ANY_TOKEN = /<([^>]+)>/g;
export type TokenList = Array<[string, number]>;
export interface PromptNetworks {
inversion: TokenList;
lora: TokenList;
}
export function extractNetworks(prompt: string): PromptNetworks {
const inversion: TokenList = [];
const lora: TokenList = [];
for (const token of prompt.matchAll(ANY_TOKEN)) {
const [_whole, match] = Array.from(token);
const [type, name, weight, ..._rest] = match.split(':');
switch (type) {
case 'inversion':
inversion.push([name, parseFloat(weight)]);
break;
case 'lora':
lora.push([name, parseFloat(weight)]);
break;
default:
// ignore others
}
}
return {
inversion,
lora,
};
}
// eslint-disable-next-line sonarjs/cognitive-complexity
export function getNetworkTokens(models: Maybe<ModelResponse>, networks: PromptNetworks): TokenList {
const tokens: TokenList = [];
if (doesExist(models)) {
for (const [name, weight] of networks.inversion) {
const model = models.networks.find((it) => it.type === 'inversion' && it.name === name);
if (doesExist(model) && model.type === 'inversion') {
tokens.push([model.token, weight]);
}
}
for (const [name, weight] of networks.lora) {
const model = models.networks.find((it) => it.type === 'lora' && it.name === name);
if (doesExist(model) && model.type === 'lora') {
for (const token of mustDefault(model.tokens, [])) {
tokens.push([token, weight]);
}
}
}
}
return tokens;
}

View File

@ -39,13 +39,28 @@ export interface ReadyResponse {
ready: boolean; ready: boolean;
} }
export interface NetworkModel { export interface ControlNetwork {
name: string; name: string;
type: 'control' | 'inversion' | 'lora'; type: 'control';
// TODO: add token
// TODO: add layer/token count
} }
export interface EmbeddingNetwork {
label: string;
name: string;
token: string;
type: 'inversion';
// TODO: add layer count
}
export interface LoraNetwork {
name: string;
label: string;
tokens: Array<string>;
type: 'lora';
}
export type NetworkModel = EmbeddingNetwork | LoraNetwork | ControlNetwork;
export interface FilterResponse { export interface FilterResponse {
mask: Array<string>; mask: Array<string>;
source: Array<string>; source: Array<string>;