from logging import getLogger
from math import ceil
from re import Pattern, compile
from typing import List, Optional, Tuple

import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline

from ..params import ImageParams, Size

logger = getLogger(__name__)
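
# Stable Diffusion latents have 4 channels at 1/8 of the pixel resolution,
# and the CLIP text encoder accepts at most 77 tokens at a time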
LATENT_CHANNELS = 4
LATENT_FACTOR = 8
MAX_TOKENS_PER_GROUP = 77

CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
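
# based on the patterns above, prompts can include tokens like <clip:skip:2>,
# <inversion:embedding-name:1.0>, and <lora:model-name:0.5>, along with
# interval ranges like token-{1,5} and alternative groups like (left|right)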


def expand_interval_ranges(prompt: str) -> str:
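    """
    Expand interval ranges like `token-{1,4}` into a numbered series of
    tokens. The end is exclusive, like Python's `range`, so `token-{1,4}`
    becomes `token-1 token-2 token-3`; an optional third number sets the
    step.
    """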
    def expand_range(match):
        (base_token, start, end, step) = match.groups(default=1)
        num_tokens = [
            f"{base_token}-{i}" for i in range(int(start), int(end), int(step))
        ]
        return " ".join(num_tokens)

    return INTERVAL_RANGE.sub(expand_range, prompt)


def expand_alternative_ranges(prompt: str) -> List[str]:
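    """
    Split alternative groups like `(cat|dog)` out of the prompt and build
    one prompt per combination, cycling through shorter groups. For example,
    `a (cat|dog) in the (snow|rain)` becomes `a cat in the snow` and
    `a dog in the rain`.
    """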
    prompt_groups = []

    last_end = 0
    next_group = ALTERNATIVE_RANGE.search(prompt)
    while next_group is not None:
        logger.debug("found alternative group in prompt: %s", next_group)

        if next_group.start() > last_end:
            skipped_prompt = prompt[last_end : next_group.start()]
            logger.trace("appending skipped section of prompt: %s", skipped_prompt)
            prompt_groups.append([skipped_prompt])

        options = next_group.group()[1:-1].split("|")
        logger.trace("split up alternative options: %s", options)
        prompt_groups.append(options)

        last_end = next_group.end()
        next_group = ALTERNATIVE_RANGE.search(prompt, last_end)

    if last_end < len(prompt):
        remaining_prompt = prompt[last_end:]
        logger.trace("appending remainder of prompt: %s", remaining_prompt)
        prompt_groups.append([remaining_prompt])

    prompt_count = max([len(group) for group in prompt_groups])
    prompts = []
    for i in range(prompt_count):
        options = []
        for group in prompt_groups:
            group_i = i % len(group)
            options.append(group[group_i])

        prompts.append("".join(options))

    return prompts


@torch.no_grad()
def expand_prompt(
    self: OnnxStableDiffusionPipeline,
    prompt: str,
    num_images_per_prompt: int,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    skip_clip_states: int = 0,
) -> np.ndarray:
    # self provides:
    #   tokenizer: CLIPTokenizer
    #   encoder: OnnxRuntimeModel

    prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
    if len(clip_tokens) > 0:
        skip_clip_states = int(clip_tokens[0][1])
        logger.info("skipping %s CLIP layers", skip_clip_states)

    batch_size = len(prompt) if isinstance(prompt, list) else 1
    prompt = expand_interval_ranges(prompt)

    # split the prompt into groups of at most MAX_TOKENS_PER_GROUP tokens
    tokens = self.tokenizer(
        prompt,
        padding="max_length",
        return_tensors="np",
        max_length=self.tokenizer.model_max_length,
        truncation=False,
    )

    groups_count = ceil(tokens.input_ids.shape[1] / MAX_TOKENS_PER_GROUP)
    logger.trace("splitting %s into %s groups", tokens.input_ids.shape, groups_count)

    groups = []
    # np.array_split(tokens.input_ids, groups_count, axis=1)
    for i in range(groups_count):
        group_start = i * MAX_TOKENS_PER_GROUP
        group_end = min(
            group_start + MAX_TOKENS_PER_GROUP, tokens.input_ids.shape[1]
        )  # or should this be 1?
        logger.trace("building group for token slice [%s : %s]", group_start, group_end)

        group_size = group_end - group_start
        if group_size < MAX_TOKENS_PER_GROUP:
            pass  # TODO: pad short groups

        groups.append(tokens.input_ids[:, group_start:group_end])

    # encode each chunk
    logger.trace("group token shapes: %s", [t.shape for t in groups])
    group_embeds = []
    for group in groups:
        logger.trace("encoding group: %s", group.shape)

        text_result = self.text_encoder(input_ids=group.astype(np.int32))
        logger.trace(
            "text encoder produced %s outputs: %s",
            len(text_result),
            [t.shape for t in text_result],
        )

        last_state, _pooled_output, *hidden_states = text_result
        if skip_clip_states > 0:
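            # take an earlier hidden state in place of the final one, then
            # re-normalize it with a freshly initialized (identity-weighted)
            # LayerNorm, standing in for the encoder's final layer norm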
            layer_norm = torch.nn.LayerNorm(last_state.shape[2])
            norm_state = layer_norm(
                torch.from_numpy(
                    hidden_states[-skip_clip_states].astype(np.float32)
                ).detach()
            )
            logger.trace(
                "normalized results after skipping %s layers: %s",
                skip_clip_states,
                norm_state.shape,
            )
            group_embeds.append(
                norm_state.numpy().astype(hidden_states[-skip_clip_states].dtype)
            )
        else:
            group_embeds.append(last_state)

    # concat those embeds
    logger.trace("group embeds shape: %s", [t.shape for t in group_embeds])
    prompt_embeds = np.concatenate(group_embeds, axis=1)
    prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        negative_prompt_embeds = self.text_encoder(
            input_ids=uncond_input.input_ids.astype(np.int32)
        )[0]
        negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
        logger.trace(
            "padding negative prompt to match input: %s, %s, %s extra tokens",
            tokens.input_ids.shape,
            negative_prompt_embeds.shape,
            negative_padding,
        )
        negative_prompt_embeds = np.pad(
            negative_prompt_embeds,
            [(0, 0), (0, negative_padding), (0, 0)],
            mode="constant",
            constant_values=0,
        )
        negative_prompt_embeds = np.repeat(
            negative_prompt_embeds, num_images_per_prompt, axis=0
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

    logger.trace("expanded prompt shape: %s", prompt_embeds.shape)
    return prompt_embeds


def get_tokens_from_prompt(
    prompt: str, pattern: Pattern
) -> Tuple[str, List[Tuple[str, float]]]:
    """
    Remove tokens matching the given pattern from the prompt, returning the
    cleaned prompt along with a list of (name, weight) pairs.

    TODO: replace with Arpeggio
    """
    remaining_prompt = prompt

    tokens = []
    next_match = pattern.search(remaining_prompt)
    while next_match is not None:
        logger.debug("found token in prompt: %s", next_match)
        name, weight = next_match.groups()
        tokens.append((name, float(weight)))

        # remove this match and look for another
        remaining_prompt = (
            remaining_prompt[: next_match.start()]
            + remaining_prompt[next_match.end() :]
        )
        next_match = pattern.search(remaining_prompt)

    return (remaining_prompt, tokens)


def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    return get_tokens_from_prompt(prompt, LORA_TOKEN)


def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    return get_tokens_from_prompt(prompt, INVERSION_TOKEN)


def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
    """
    Build random latents of the correct shape for the given seed and size.
    From https://www.travelneil.com/stable-diffusion-updates.html.
    """
    latents_shape = (
        batch,
        LATENT_CHANNELS,
        size.height // LATENT_FACTOR,
        size.width // LATENT_FACTOR,
    )
    rng = np.random.default_rng(seed)
    image_latents = rng.standard_normal(latents_shape).astype(np.float32)
    return image_latents


def get_tile_latents(
    full_latents: np.ndarray,
    dims: Tuple[int, int, int],
    size: Size,
) -> np.ndarray:
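    """
    Slice a tile out of the full latents, padding with a reflection of the
    latents when the tile extends past the edge.
    """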
    x, y, tile = dims
    t = tile // LATENT_FACTOR
    x = x // LATENT_FACTOR
    y = y // LATENT_FACTOR
    xt = x + t
    yt = y + t

    mx = size.width // LATENT_FACTOR
    my = size.height // LATENT_FACTOR

    tile_latents = full_latents[:, :, y:yt, x:xt]

    if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
        px = mx - tile_latents.shape[3]
        py = my - tile_latents.shape[2]

        tile_latents = np.pad(
            tile_latents, ((0, 0), (0, 0), (0, py), (0, px)), mode="reflect"
        )

    return tile_latents


def get_scaled_latents(
    seed: int,
    size: Size,
    batch: int = 1,
    scale: int = 1,
) -> np.ndarray:
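    """
    Build random latents for the given seed and size, then upscale them
    with bilinear interpolation.
    """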
    latents = get_latents_from_seed(seed, size, batch=batch)
    latents = torch.from_numpy(latents)

    scaled = torch.nn.functional.interpolate(
        latents, scale_factor=(scale, scale), mode="bilinear"
    )
    return scaled.numpy()


def parse_prompt(
    params: ImageParams,
) -> Tuple[List[Tuple[str, str]], List[Tuple[str, float]], List[Tuple[str, float]]]:
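    """
    Split LoRA and Textual Inversion tokens out of the prompt and negative
    prompt, then expand alternative ranges and pair each prompt with a
    negative prompt, cycling the shorter list to match the longer one.
    """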
    prompt, loras = get_loras_from_prompt(params.input_prompt)
    prompt, inversions = get_inversions_from_prompt(prompt)
    params.prompt = prompt

    neg_prompt = None
    if params.input_negative_prompt is not None:
        neg_prompt, neg_loras = get_loras_from_prompt(params.input_negative_prompt)
        neg_prompt, neg_inversions = get_inversions_from_prompt(neg_prompt)
        params.negative_prompt = neg_prompt

        # TODO: check whether these need to be * -1
        loras.extend(neg_loras)
        inversions.extend(neg_inversions)

    prompts = expand_alternative_ranges(prompt)
    if neg_prompt is not None:
        neg_prompts = expand_alternative_ranges(neg_prompt)
    else:
        neg_prompts = [None] * len(prompts)

    logger.trace("generated prompts: %s, %s", prompts, neg_prompts)

    # count these ahead of time, because the lists will change
    prompt_count = len(prompts)
    neg_prompt_count = len(neg_prompts)

    if prompt_count < neg_prompt_count:
        # extend prompts
        for i in range(prompt_count, neg_prompt_count):
            prompts.append(prompts[i % prompt_count])
    elif prompt_count > neg_prompt_count:
        # extend neg_prompts
        for i in range(neg_prompt_count, prompt_count):
            neg_prompts.append(neg_prompts[i % neg_prompt_count])

    return list(zip(prompts, neg_prompts)), loras, inversions


def encode_prompt(
    pipe: OnnxStableDiffusionPipeline,
    prompt_pairs: List[Tuple[str, str]],
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
) -> List[np.ndarray]:
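    """
    Encode a list of (prompt, negative prompt) pairs using the pipeline's
    prompt encoder, returning one embedding per pair.
    """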
    return [
        pipe._encode_prompt(
            prompt,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=neg_prompt,
        )
        for prompt, neg_prompt in prompt_pairs
    ]