diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 29ff91bb..b844920a 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -23,6 +23,8 @@ from diffusers import (
 )
 from transformers import CLIPTokenizer
 
+from onnx_web.diffusers.utils import expand_prompt
+
 try:
     from diffusers import DEISMultistepScheduler
 except ImportError:
@@ -230,6 +232,10 @@ def load_pipeline(
         if device is not None and hasattr(pipe, "to"):
             pipe = pipe.to(device.torch_str())
 
+        # monkey-patch pipeline
+        if not lpw:
+            pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
+
         server.cache.set("diffusion", pipe_key, pipe)
         server.cache.set("scheduler", scheduler_key, components["scheduler"])
 
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index d5d782cd..5bd05e43 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -1,2 +1,115 @@
-def expand_prompt(prompt: str) -> str:
-    return prompt
+from logging import getLogger
+from math import ceil
+from typing import List, Optional
+
+import numpy as np
+from diffusers import OnnxStableDiffusionPipeline
+
+logger = getLogger(__name__)
+
+MAX_TOKENS_PER_GROUP = 77
+
+
+def expand_prompt(
+    self: OnnxStableDiffusionPipeline,
+    prompt: str,
+    num_images_per_prompt: int,
+    do_classifier_free_guidance: bool,
+    negative_prompt: Optional[str] = None,
+) -> "np.NDArray":
+    # self provides:
+    #   tokenizer: CLIPTokenizer
+    #   encoder: OnnxRuntimeModel
+
+    batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+    # split prompt into 75 token chunks
+    tokens = self.tokenizer(
+        prompt,
+        padding="max_length",
+        return_tensors="np",
+        max_length=self.tokenizer.model_max_length,
+        truncation=False,
+    )
+
+    groups_count = ceil(tokens.input_ids.shape[1] / MAX_TOKENS_PER_GROUP)
+    logger.info("splitting %s into %s groups", tokens.input_ids.shape, groups_count)
+
+    groups = []
+    # np.array_split(tokens.input_ids, groups_count, axis=1)
+    for i in range(groups_count):
+        group_start = i * MAX_TOKENS_PER_GROUP
+        group_end = min(
+            group_start + MAX_TOKENS_PER_GROUP, tokens.input_ids.shape[1]
+        )  # or should this be 1?
+        logger.info("building group for token slice [%s : %s]", group_start, group_end)
+        groups.append(tokens.input_ids[:, group_start:group_end])
+
+    # encode each chunk
+    logger.info("group token shapes: %s", [t.shape for t in groups])
+    group_embeds = []
+    for group in groups:
+        logger.info("encoding group: %s", group.shape)
+        embeds = self.text_encoder(input_ids=group.astype(np.int32))[0]
+        group_embeds.append(embeds)
+
+    # concat those embeds
+    logger.info("group embeds shape: %s", [t.shape for t in group_embeds])
+    prompt_embeds = np.concatenate(group_embeds, axis=1)
+    prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+    # get unconditional embeddings for classifier free guidance
+    if do_classifier_free_guidance:
+        uncond_tokens: List[str]
+        if negative_prompt is None:
+            uncond_tokens = [""] * batch_size
+        elif type(prompt) is not type(negative_prompt):
+            raise TypeError(
+                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                f" {type(prompt)}."
+            )
+        elif isinstance(negative_prompt, str):
+            uncond_tokens = [negative_prompt] * batch_size
+        elif batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )
+        else:
+            uncond_tokens = negative_prompt
+
+        uncond_input = self.tokenizer(
+            uncond_tokens,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="np",
+        )
+        negative_prompt_embeds = self.text_encoder(
+            input_ids=uncond_input.input_ids.astype(np.int32)
+        )[0]
+        negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
+        logger.info(
+            "padding negative prompt to match input: %s, %s, %s extra tokens",
+            tokens.input_ids.shape,
+            negative_prompt_embeds.shape,
+            negative_padding,
+        )
+        negative_prompt_embeds = np.pad(
+            negative_prompt_embeds,
+            [(0, 0), (0, negative_padding), (0, 0)],
+            mode="constant",
+            constant_values=0,
+        )
+        negative_prompt_embeds = np.repeat(
+            negative_prompt_embeds, num_images_per_prompt, axis=0
+        )
+
+        # For classifier free guidance, we need to do two forward passes.
+        # Here we concatenate the unconditional and text embeddings into a single batch
+        # to avoid doing two forward passes
+        prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
+
+    logger.info("expanded prompt shape: %s", prompt_embeds.shape)
+    return prompt_embeds
diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py
index 87c33f67..3afefe65 100644
--- a/api/onnx_web/worker/worker.py
+++ b/api/onnx_web/worker/worker.py
@@ -54,7 +54,10 @@ def worker_main(context: WorkerContext, server: ServerContext):
             logger.info("worker got keyboard interrupt")
             exit(EXIT_INTERRUPT)
         except ValueError as e:
-            logger.info("value error in worker, exiting: %s", e)
+            logger.info(
+                "value error in worker, exiting: %s",
+                format_exception(type(e), e, e.__traceback__),
+            )
             exit(EXIT_ERROR)
         except Exception as e:
             if "Failed to allocate memory" in str(e):