from logging import getLogger
from math import ceil
from re import Pattern, compile
from typing import List, Optional, Tuple

import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline

logger = getLogger(__name__)
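
# inline prompt tokens matched by the patterns below: <clip:skip:N> sets how many
# CLIP layers to skip (only the trailing number is used), while
# <inversion:name:weight> and <lora:name:weight> select extra networks by name and
# weight; name-{start,end,step} ranges are expanded by expand_prompt_ranges.
# MAX_TOKENS_PER_GROUP is the CLIP context length (77 tokens per encoder pass).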
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([-\w]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([-\w]+):(-?[\.|\d]+)\>")
MAX_TOKENS_PER_GROUP = 77
PATTERN_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
def expand_prompt_ranges(prompt: str) -> str:
    """
    Expand range syntax like "name-{1,4}" into "name-1 name-2 name-3"
    (the end value is exclusive and the step defaults to 1).
    """

    def expand_range(match):
        (base_token, start, end, step) = match.groups(default=1)
        num_tokens = [
            f"{base_token}-{i}" for i in range(int(start), int(end), int(step))
        ]
        return " ".join(num_tokens)

    return PATTERN_RANGE.sub(expand_range, prompt)


@torch.no_grad()
def expand_prompt(
    self: OnnxStableDiffusionPipeline,
    prompt: str,
    num_images_per_prompt: int,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[str] = None,
    skip_clip_states: Optional[int] = 0,
) -> "np.NDArray":
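    """
    Expand a prompt into text encoder embeddings, splitting prompts longer than
    the 77-token CLIP window into groups, encoding each group, and concatenating
    the results. An inline <clip:...:N> token (or skip_clip_states) selects how
    many CLIP layers to skip; when classifier-free guidance is enabled, the
    negative prompt embedding is padded to the same length and prepended.

    Intended to be bound to an OnnxStableDiffusionPipeline, so self provides the
    tokenizer and text encoder.
    """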
    # self provides:
    #   tokenizer: CLIPTokenizer
    #   encoder: OnnxRuntimeModel
    prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN)
    if len(clip_tokens) > 0:
        skip_clip_states = int(clip_tokens[0][1])
        logger.info("skipping %s CLIP layers", skip_clip_states)

    batch_size = len(prompt) if isinstance(prompt, list) else 1
    prompt = expand_prompt_ranges(prompt)

    # split prompt into 75 token chunks
    tokens = self.tokenizer(
        prompt,
        padding="max_length",
        return_tensors="np",
        max_length=self.tokenizer.model_max_length,
        truncation=False,
    )

    groups_count = ceil(tokens.input_ids.shape[1] / MAX_TOKENS_PER_GROUP)
    logger.trace("splitting %s into %s groups", tokens.input_ids.shape, groups_count)

    groups = []
    # np.array_split(tokens.input_ids, groups_count, axis=1)
    for i in range(groups_count):
        group_start = i * MAX_TOKENS_PER_GROUP
        group_end = min(
            group_start + MAX_TOKENS_PER_GROUP, tokens.input_ids.shape[1]
        )  # or should this be 1?
        logger.trace("building group for token slice [%s : %s]", group_start, group_end)
        groups.append(tokens.input_ids[:, group_start:group_end])

    # encode each chunk
    logger.trace("group token shapes: %s", [t.shape for t in groups])

    group_embeds = []
    for group in groups:
        logger.trace("encoding group: %s", group.shape)

        text_result = self.text_encoder(input_ids=group.astype(np.int32))
        logger.trace(
            "text encoder produced %s outputs: %s", len(text_result), text_result
        )

        last_state, _pooled_output, *hidden_states = text_result
        if skip_clip_states > 0:
            # when skipping CLIP layers, take an earlier hidden state and run it
            # through a fresh (default-initialized) LayerNorm in place of the
            # encoder's final layer norm
            layer_norm = torch.nn.LayerNorm(last_state.shape[2])
            norm_state = layer_norm(
                torch.from_numpy(
                    hidden_states[-skip_clip_states].astype(np.float32)
                ).detach()
            )
            logger.trace(
                "normalized results after skipping %s layers: %s",
                skip_clip_states,
                norm_state.shape,
            )
            group_embeds.append(
                norm_state.numpy().astype(hidden_states[-skip_clip_states].dtype)
            )
        else:
            group_embeds.append(last_state)

    # concat those embeds
    logger.trace("group embeds shape: %s", [t.shape for t in group_embeds])
    prompt_embeds = np.concatenate(group_embeds, axis=1)
    prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

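    # when the prompt spans multiple groups, prompt_embeds now covers the full
    # tokenized prompt length rather than a single 77-token window; the negative
    # prompt embedding is padded below to match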
    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        negative_prompt_embeds = self.text_encoder(
            input_ids=uncond_input.input_ids.astype(np.int32)
        )[0]

        negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
        logger.trace(
            "padding negative prompt to match input: %s, %s, %s extra tokens",
            tokens.input_ids.shape,
            negative_prompt_embeds.shape,
            negative_padding,
        )
        negative_prompt_embeds = np.pad(
            negative_prompt_embeds,
            [(0, 0), (0, negative_padding), (0, 0)],
            mode="constant",
            constant_values=0,
        )
        negative_prompt_embeds = np.repeat(
            negative_prompt_embeds, num_images_per_prompt, axis=0
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

    logger.trace("expanded prompt shape: %s", prompt_embeds.shape)
    return prompt_embeds


def get_tokens_from_prompt(
    prompt: str, pattern: Pattern
) -> Tuple[str, List[Tuple[str, float]]]:
    """
    Remove every token matching the pattern from the prompt, returning the
    cleaned prompt and a list of (name, weight) pairs for the removed tokens.

    TODO: replace with Arpeggio
    """
    remaining_prompt = prompt

    tokens = []
    next_match = pattern.search(remaining_prompt)
    while next_match is not None:
        logger.debug("found token in prompt: %s", next_match)
        name, weight = next_match.groups()
        tokens.append((name, float(weight)))

        # remove this match and look for another
        remaining_prompt = (
            remaining_prompt[: next_match.start()]
            + remaining_prompt[next_match.end() :]
        )
        next_match = pattern.search(remaining_prompt)

    return (remaining_prompt, tokens)
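
# example: get_loras_from_prompt("a painting <lora:pastel:0.8>")
# returns ("a painting ", [("pastel", 0.8)])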
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    return get_tokens_from_prompt(prompt, LORA_TOKEN)


def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    return get_tokens_from_prompt(prompt, INVERSION_TOKEN)