feat(api): add support for extremely long prompts
This commit is contained in:
parent
c45915e558
commit
66c42485cb
|
@ -23,6 +23,8 @@ from diffusers import (
|
|||
)
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from onnx_web.diffusers.utils import expand_prompt
|
||||
|
||||
try:
|
||||
from diffusers import DEISMultistepScheduler
|
||||
except ImportError:
|
||||
|
@ -230,6 +232,10 @@ def load_pipeline(
|
|||
if device is not None and hasattr(pipe, "to"):
|
||||
pipe = pipe.to(device.torch_str())
|
||||
|
||||
# monkey-patch pipeline
|
||||
if not lpw:
|
||||
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
||||
|
||||
server.cache.set("diffusion", pipe_key, pipe)
|
||||
server.cache.set("scheduler", scheduler_key, components["scheduler"])
|
||||
|
||||
|
|
|
@ -1,2 +1,115 @@
|
|||
def expand_prompt(prompt: str) -> str:
|
||||
return prompt
|
||||
from logging import getLogger
|
||||
from math import ceil
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from diffusers import OnnxStableDiffusionPipeline
|
||||
|
||||
logger = getLogger(__name__)

# CLIP text encoders accept at most 77 tokens per pass (75 prompt tokens
# plus the begin/end markers), so longer prompts are encoded in groups of
# this size and the embeddings concatenated afterwards.
MAX_TOKENS_PER_GROUP = 77


def expand_prompt(
    self: "OnnxStableDiffusionPipeline",
    prompt: str,
    num_images_per_prompt: int,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[str] = None,
) -> np.ndarray:
    """
    Encode a prompt of arbitrary length by tokenizing it without truncation,
    splitting the token ids into MAX_TOKENS_PER_GROUP-sized groups, encoding
    each group separately, and concatenating the embeddings along the token
    axis.

    Intended to be bound onto a pipeline in place of its `_encode_prompt`
    method; `self` provides:
      - tokenizer: CLIPTokenizer
      - text_encoder: OnnxRuntimeModel

    Returns the prompt embeddings repeated `num_images_per_prompt` times
    along the batch axis. When `do_classifier_free_guidance` is set, the
    unconditional embeddings are zero-padded to the same token length and
    concatenated in front of the prompt embeddings, so both passes share a
    single batch.

    Raises TypeError when `negative_prompt` has a different type than
    `prompt`, and ValueError when their batch sizes do not match.
    """
    batch_size = len(prompt) if isinstance(prompt, list) else 1

    # tokenize the full prompt without truncation; padding="max_length"
    # guarantees at least one full group even for short prompts
    tokens = self.tokenizer(
        prompt,
        padding="max_length",
        return_tensors="np",
        max_length=self.tokenizer.model_max_length,
        truncation=False,
    )

    # split the token ids into groups of up to MAX_TOKENS_PER_GROUP tokens
    groups_count = ceil(tokens.input_ids.shape[1] / MAX_TOKENS_PER_GROUP)
    logger.debug("splitting %s into %s groups", tokens.input_ids.shape, groups_count)

    groups = []
    for i in range(groups_count):
        group_start = i * MAX_TOKENS_PER_GROUP
        # clamp the final group to the end of the token sequence
        group_end = min(
            group_start + MAX_TOKENS_PER_GROUP, tokens.input_ids.shape[1]
        )
        logger.debug("building group for token slice [%s : %s]", group_start, group_end)
        groups.append(tokens.input_ids[:, group_start:group_end])

    # encode each group separately, then join them along the token axis
    logger.debug("group token shapes: %s", [t.shape for t in groups])
    group_embeds = []
    for group in groups:
        logger.debug("encoding group: %s", group.shape)
        embeds = self.text_encoder(input_ids=group.astype(np.int32))[0]
        group_embeds.append(embeds)

    logger.debug("group embeds shape: %s", [t.shape for t in group_embeds])
    prompt_embeds = np.concatenate(group_embeds, axis=1)
    prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # the unconditional prompt is encoded in a single (truncated) pass
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        negative_prompt_embeds = self.text_encoder(
            input_ids=uncond_input.input_ids.astype(np.int32)
        )[0]

        # zero-pad the negative embeddings along the token axis to match the
        # expanded prompt's length before concatenating the two batches
        negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
        logger.debug(
            "padding negative prompt to match input: %s, %s, %s extra tokens",
            tokens.input_ids.shape,
            negative_prompt_embeds.shape,
            negative_padding,
        )
        negative_prompt_embeds = np.pad(
            negative_prompt_embeds,
            [(0, 0), (0, negative_padding), (0, 0)],
            mode="constant",
            constant_values=0,
        )
        negative_prompt_embeds = np.repeat(
            negative_prompt_embeds, num_images_per_prompt, axis=0
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

    logger.debug("expanded prompt shape: %s", prompt_embeds.shape)
    return prompt_embeds
|
||||
|
|
|
@ -54,7 +54,10 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
|||
logger.info("worker got keyboard interrupt")
|
||||
exit(EXIT_INTERRUPT)
|
||||
except ValueError as e:
|
||||
logger.info("value error in worker, exiting: %s", e)
|
||||
logger.info(
|
||||
"value error in worker, exiting: %s",
|
||||
format_exception(type(e), e, e.__traceback__),
|
||||
)
|
||||
exit(EXIT_ERROR)
|
||||
except Exception as e:
|
||||
if "Failed to allocate memory" in str(e):
|
||||
|
|
Loading…
Reference in New Issue