# onnx-web/api/onnx_web/diffusers/utils.py
import random
from copy import deepcopy
from logging import getLogger
from math import ceil
from re import Pattern, compile
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline

from ..constants import LATENT_CHANNELS, LATENT_FACTOR
from ..params import ImageParams, Size

logger = getLogger(__name__)

MAX_TOKENS_PER_GROUP = 77

ANY_TOKEN = compile(r"\<([^\>]*)\>")
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
# keyword alternation is non-capturing so the default parser still gets (name, weight)
INVERSION_TOKEN = compile(r"\<(?:embeddings|inversion):([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
REGION_TOKEN = compile(
    r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
)
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w\. ]+)__")

INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")


def expand_interval_ranges(prompt: str) -> str:
    def expand_range(match):
        (base_token, start, end, step) = match.groups(default=1)
        num_tokens = [
            f"{base_token}-{i}" for i in range(int(start), int(end), int(step))
        ]
        return " ".join(num_tokens)

    return INTERVAL_RANGE.sub(expand_range, prompt)
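
# For example, expand_interval_ranges("a cloud-{1,4} sky") returns
# "a cloud-1 cloud-2 cloud-3 sky": the end value is exclusive and the optional third
# number in the braces is the step, defaulting to 1.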


def expand_alternative_ranges(prompt: str) -> List[str]:
    prompt_groups = []
    last_end = 0

    next_group = ALTERNATIVE_RANGE.search(prompt)
    while next_group is not None:
        logger.debug("found alternative group in prompt: %s", next_group)

        if next_group.start() > last_end:
            skipped_prompt = prompt[last_end : next_group.start()]
            logger.trace("appending skipped section of prompt: %s", skipped_prompt)
            prompt_groups.append([skipped_prompt])

        options = next_group.group()[1:-1].split("|")
        logger.trace("split up alternative options: %s", options)
        prompt_groups.append(options)

        last_end = next_group.end()
        next_group = ALTERNATIVE_RANGE.search(prompt, last_end)

    if last_end < len(prompt):
        remaining_prompt = prompt[last_end:]
        logger.trace("appending remainder of prompt: %s", remaining_prompt)
        prompt_groups.append([remaining_prompt])

    prompt_count = max([len(group) for group in prompt_groups])

    prompts = []
    for i in range(prompt_count):
        options = []
        for group in prompt_groups:
            group_i = i % len(group)
            options.append(group[group_i])

        prompts.append("".join(options))

    return prompts
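
# For example, expand_alternative_ranges("a (red|blue) house") returns
# ["a red house", "a blue house"]; shorter groups repeat to fill the longest one.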


@torch.no_grad()
def expand_prompt(
    self: OnnxStableDiffusionPipeline,
    prompt: str,
    num_images_per_prompt: int,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    skip_clip_states: int = 0,
) -> np.ndarray:
    # self provides:
    #   tokenizer: CLIPTokenizer
    #   text_encoder: OnnxRuntimeModel

    prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
    if len(clip_tokens) > 0:
        skip_clip_states = int(clip_tokens[0][1])
        logger.info("skipping %s CLIP layers", skip_clip_states)

    batch_size = len(prompt) if isinstance(prompt, list) else 1
    prompt = expand_interval_ranges(prompt)

    # split the prompt into groups of up to MAX_TOKENS_PER_GROUP (77) tokens
    tokens = self.tokenizer(
        prompt,
        padding="max_length",
        return_tensors="np",
        max_length=self.tokenizer.model_max_length,
        truncation=False,
    )
    groups_count = ceil(tokens.input_ids.shape[1] / MAX_TOKENS_PER_GROUP)
    logger.trace("splitting %s into %s groups", tokens.input_ids.shape, groups_count)

    groups = []
    # np.array_split(tokens.input_ids, groups_count, axis=1)
    for i in range(groups_count):
        group_start = i * MAX_TOKENS_PER_GROUP
        group_end = min(
            group_start + MAX_TOKENS_PER_GROUP, tokens.input_ids.shape[1]
        )  # or should this be 1?
        logger.trace(
            "building group for token slice [%s : %s]", group_start, group_end
        )

        group_size = group_end - group_start
        if group_size < MAX_TOKENS_PER_GROUP:
            pass  # TODO: pad short groups

        groups.append(tokens.input_ids[:, group_start:group_end])

    # encode each group
    logger.trace("group token shapes: %s", [t.shape for t in groups])

    group_embeds = []
    for group in groups:
        logger.trace("encoding group: %s", group.shape)

        text_result = self.text_encoder(input_ids=group.astype(np.int32))
        logger.trace(
            "text encoder produced %s outputs: %s",
            len(text_result),
            [t.shape for t in text_result],
        )

        last_state, _pooled_output, *hidden_states = text_result
        if skip_clip_states > 0:
            # TODO: why is this normalized?
            layer_norm = torch.nn.LayerNorm(last_state.shape[2])
            norm_state = layer_norm(
                torch.from_numpy(
                    hidden_states[-skip_clip_states].astype(np.float32)
                ).detach()
            )
            logger.trace(
                "normalized results after skipping %s layers: %s",
                skip_clip_states,
                norm_state.shape,
            )
            group_embeds.append(
                norm_state.numpy().astype(hidden_states[-skip_clip_states].dtype)
            )
        else:
            group_embeds.append(last_state)

    # concatenate the group embeddings along the token axis
    logger.trace("group embeds shape: %s", [t.shape for t in group_embeds])
    prompt_embeds = np.concatenate(group_embeds, axis=1)
    prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        negative_prompt_embeds = self.text_encoder(
            input_ids=uncond_input.input_ids.astype(np.int32)
        )[0]

        # pad the negative embeddings out to the expanded prompt length
        negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
        logger.trace(
            "padding negative prompt to match input: %s, %s, %s extra tokens",
            tokens.input_ids.shape,
            negative_prompt_embeds.shape,
            negative_padding,
        )
        negative_prompt_embeds = np.pad(
            negative_prompt_embeds,
            [(0, 0), (0, negative_padding), (0, 0)],
            mode="constant",
            constant_values=0,
        )
        negative_prompt_embeds = np.repeat(
            negative_prompt_embeds, num_images_per_prompt, axis=0
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes.
        prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

    logger.trace("expanded prompt shape: %s", prompt_embeds.shape)
    return prompt_embeds
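
# Note: for a single prompt string with classifier-free guidance enabled, the returned
# array stacks the zero-padded negative embeddings ahead of the positive ones along the
# batch axis, giving a shape of roughly (2 * num_images_per_prompt, token_count,
# hidden_size), where token_count covers every 77-token group of the expanded prompt.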


def parse_float_group(group: Tuple[str, str]) -> Tuple[str, float]:
    name, weight = group
    return (name, float(weight))


def get_tokens_from_prompt(
    prompt: str,
    pattern: Pattern,
    parser=parse_float_group,
) -> Tuple[str, List[Tuple[str, float]]]:
    remaining_prompt = prompt
    tokens = []

    next_match = pattern.search(remaining_prompt)
    while next_match is not None:
        logger.debug("found token in prompt: %s", next_match)
        group = next_match.groups()
        tokens.append(parser(group))

        # remove this match and look for another
        remaining_prompt = (
            remaining_prompt[: next_match.start()]
            + remaining_prompt[next_match.end() :]
        )
        next_match = pattern.search(remaining_prompt)

    return (remaining_prompt, tokens)


def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    return get_tokens_from_prompt(prompt, LORA_TOKEN)


def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    return get_tokens_from_prompt(prompt, INVERSION_TOKEN)
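
# For example, get_loras_from_prompt("a castle <lora:fantasy:0.8>") returns
# ("a castle ", [("fantasy", 0.8)]): the token is cut out of the prompt and collected
# as a (name, weight) pair.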


def random_seed(generator=None) -> int:
    if generator is None:
        generator = np.random

    return generator.randint(np.iinfo(np.int32).max)


def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
    """
    From https://www.travelneil.com/stable-diffusion-updates.html.
    """
    latents_shape = (
        batch,
        LATENT_CHANNELS,
        size.height // LATENT_FACTOR,
        size.width // LATENT_FACTOR,
    )
    rng = np.random.default_rng(seed)
    image_latents = rng.standard_normal(latents_shape).astype(np.float32)
    return image_latents
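
# Assuming the usual Stable Diffusion values for these constants (LATENT_CHANNELS = 4,
# LATENT_FACTOR = 8), a 512x512 request produces latents of shape (batch, 4, 64, 64).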


def expand_latents(
    latents: np.ndarray,
    seed: int,
    size: Size,
    sigma: float = 1.0,
) -> np.ndarray:
    batch, _channels, height, width = latents.shape
    extra_latents = get_latents_from_seed(seed, size, batch=batch)
    extra_latents[:, :, 0:height, 0:width] = latents
    return extra_latents * np.float64(sigma)


def resize_latent_shape(
    latents: np.ndarray,
    size: Tuple[int, int],
) -> Tuple[int, int, int, int]:
    return (latents.shape[0], latents.shape[1], *size)


def get_tile_latents(
    full_latents: np.ndarray,
    seed: int,
    size: Size,
    dims: Tuple[int, int, int],
) -> np.ndarray:
    x, y, tile = dims
    t = tile // LATENT_FACTOR
    x = max(0, x // LATENT_FACTOR)
    y = max(0, y // LATENT_FACTOR)
    xt = x + t
    yt = y + t

    logger.trace(
        "getting tile latents: [%s:%s, %s:%s] within %s",
        y,
        yt,
        x,
        xt,
        full_latents.shape,
    )

    tile_latents = full_latents[:, :, y:yt, x:xt]

    if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
        tile_latents = expand_latents(tile_latents, seed, size)

    return tile_latents
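
# For example, assuming LATENT_FACTOR = 8, a tile at dims = (256, 0, 512) maps to the
# latent slice [:, :, 0:64, 32:96]; tiles that run past the edge of the full latents
# are expanded back up to the tile size with fresh noise from the same seed.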


def get_scaled_latents(
    seed: int,
    size: Size,
    batch: int = 1,
    scale: int = 1,
) -> np.ndarray:
    latents = get_latents_from_seed(seed, size, batch=batch)
    latents = torch.from_numpy(latents)

    scaled = torch.nn.functional.interpolate(
        latents, scale_factor=(scale, scale), mode="bilinear"
    )
    return scaled.numpy()


def parse_prompt(
    params: ImageParams,
) -> Tuple[
    List[Tuple[str, str]],
    List[Tuple[str, float]],
    List[Tuple[str, float]],
    Tuple[str, str],
]:
    """
    TODO: return a more structured format
    """
    prompt, loras = get_loras_from_prompt(params.prompt)
    prompt, inversions = get_inversions_from_prompt(prompt)

    neg_prompt = None
    if params.negative_prompt is not None:
        neg_prompt, neg_loras = get_loras_from_prompt(params.negative_prompt)
        neg_prompt, neg_inversions = get_inversions_from_prompt(neg_prompt)
        loras.extend(neg_loras)
        inversions.extend(neg_inversions)

    prompts = expand_alternative_ranges(prompt)
    if neg_prompt is not None:
        neg_prompts = expand_alternative_ranges(neg_prompt)
    else:
        neg_prompts = [None] * len(prompts)

    logger.trace("generated prompts: %s, %s", prompts, neg_prompts)

    # count these ahead of time, because they will change
    prompt_count = len(prompts)
    neg_prompt_count = len(neg_prompts)

    if prompt_count < neg_prompt_count:
        # extend prompts
        for i in range(prompt_count, neg_prompt_count):
            prompts.append(prompts[i % prompt_count])
    elif prompt_count > neg_prompt_count:
        # extend neg_prompts
        for i in range(neg_prompt_count, prompt_count):
            neg_prompts.append(neg_prompts[i % neg_prompt_count])

    return list(zip(prompts, neg_prompts)), loras, inversions, (prompt, neg_prompt)


def encode_prompt(
    pipe: OnnxStableDiffusionPipeline,
    prompt_pairs: List[Tuple[str, str]],
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
) -> List[np.ndarray]:
    """
    TODO: does not work with SDXL, fix or turn into a pipeline patch
    """
    return [
        pipe._encode_prompt(
            remove_tokens(prompt),
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=remove_tokens(neg_prompt),
        )
        for prompt, neg_prompt in prompt_pairs
    ]


def parse_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -> str:
    next_match = WILDCARD_TOKEN.search(prompt)
    remaining_prompt = prompt

    # prep a local copy to avoid mutating the main one
    wildcards = deepcopy(wildcards)
    random.seed(seed)

    while next_match is not None:
        logger.debug("found wildcard in prompt: %s", next_match)
        name, *rest = next_match.groups()

        wildcard = ""
        if name in wildcards:
            wildcard = pop_random(wildcards.get(name))
        else:
            logger.warning("unknown wildcard: %s", name)

        remaining_prompt = (
            remaining_prompt[: next_match.start()]
            + wildcard
            + remaining_prompt[next_match.end() :]
        )
        next_match = WILDCARD_TOKEN.search(remaining_prompt)

    return remaining_prompt


def replace_wildcards(params: ImageParams, wildcards: Dict[str, List[str]]):
    params.prompt = parse_wildcards(params.prompt, params.seed, wildcards)

    if params.negative_prompt is not None:
        params.negative_prompt = parse_wildcards(
            params.negative_prompt, params.seed, wildcards
        )


def pop_random(list: List[str]) -> str:
    """
    From https://stackoverflow.com/a/14088129
    """
    i = random.randrange(len(list))

    # swap the chosen item to the end, then pop it, so removal is O(1)
    list[i], list[-1] = list[-1], list[i]
    return list.pop()


def repair_nan(tile: np.ndarray) -> np.ndarray:
    flat_tile = tile.flatten()
    flat_mask = np.isnan(flat_tile)

    if np.any(flat_mask):
        logger.warning("repairing NaN values in image")
        # replace each NaN with the most recent preceding valid value
        # (leading NaNs fall back to index 0)
        indices = np.where(~flat_mask, np.arange(flat_mask.shape[0]), 0)
        np.maximum.accumulate(indices, out=indices)
        return np.reshape(flat_tile[indices], tile.shape)
    else:
        return tile


def slice_prompt(prompt: str, slice: int) -> str:
    if "||" in prompt:
        parts = prompt.split("||")
        return parts[min(slice, len(parts) - 1)]
    else:
        return prompt
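
# For example, slice_prompt("day scene || night scene", 1) returns " night scene"
# (whitespace around the "||" separator is kept), and out-of-range slice indices
# clamp to the last part.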


Region = Tuple[
    int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str
]


def parse_region_group(group: Tuple[str, ...]) -> Region:
    top, left, bottom, right, weight, feather, prompt = group

    # break down the feather section into a radius and the edges to feather
    feather_radius, *feather_edges = feather.split("_")
    if len(feather_edges) == 0:
        feather_edges = "TLBR"
    else:
        feather_edges = "".join(feather_edges)

    return (
        int(top),
        int(left),
        int(bottom),
        int(right),
        float(weight),
        (
            float(feather_radius),
            (
                "T" in feather_edges,
                "L" in feather_edges,
                "B" in feather_edges,
                "R" in feather_edges,
            ),
        ),
        prompt,
    )


def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
    return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group)
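
# For example, parse_regions("a castle <region:0:0:256:256:0.75:8.0_TL:stone tower>")
# returns ("a castle ", [(0, 0, 256, 256, 0.75, (8.0, (True, True, False, False)),
# "stone tower")]), with the boolean flags marking the top, left, bottom, and right
# edges to feather.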


Reseed = Tuple[int, int, int, int, int]


def parse_reseed_group(group) -> Reseed:
    top, left, bottom, right, seed = group
    return (
        int(top),
        int(left),
        int(bottom),
        int(right),
        int(seed),
    )


def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
    return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
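
# For example, parse_reseed("a castle <reseed:0:0:64:64:12345>") returns
# ("a castle ", [(0, 0, 64, 64, 12345)]).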


def skip_group(group) -> Any:
    return group


def remove_tokens(prompt: Optional[str]) -> Optional[str]:
    if prompt is None:
        return prompt

    remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN, parser=skip_group)
    return remainder
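
# For example, remove_tokens("a castle<lora:fantasy:0.8> at night") returns
# "a castle at night", stripping every <...> token while keeping the rest of the
# prompt intact.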