diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py new file mode 100644 index 00000000..643f5bf1 --- /dev/null +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -0,0 +1,105 @@ +from typing import Any, Literal + +import numpy as np +from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput +from torch import FloatTensor, Tensor + + +class SchedulerPatch: + scheduler: Any + + def __init__(self, scheduler): + self.scheduler = scheduler + + def step( + self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor + ) -> DDIMSchedulerOutput: + result = self.scheduler.step(model_output, timestep, sample) + + white_point = 0 + black_point = 8 + center_line = result.prev_sample.shape[2] // 2 + direction = "horizontal" + + mirrored_latents = mirror_latents( + result.prev_sample, black_point, white_point, center_line, direction + ) + + return DDIMSchedulerOutput( + prev_sample=mirrored_latents, + pred_original_sample=result.pred_original_sample, + ) + + +def mirror_latents( + latents: np.ndarray, + black_point: int, + white_point: int, + center_line: int, + direction: Literal["horizontal", "vertical"], +) -> np.ndarray: + gradient = np.linspace(1, 0, white_point - black_point).astype(np.float32) + gradient = np.pad( + gradient, (black_point, center_line - white_point), mode="constant" + ) + gradient = np.reshape([gradient, np.flip(gradient)], -1) + gradient = np.expand_dims(gradient, (0, 1, 2)) + + if direction == "horizontal": + pad_left = max(0, -center_line) + pad_right = max(0, 2 * center_line - latents.shape[3]) + + # create the symmetrical copies + padded_array = np.pad( + latents, ((0, 0), (0, 0), (0, 0), (pad_left, pad_right)), mode="constant" + ) + flipped_array = np.flip(padded_array, axis=3) + + # apply the gradient to both copies + padded_gradiated = np.multiply(padded_array, gradient) + flipped_gradiated = np.multiply(flipped_array, gradient) + + # produce masks + mask = np.ones_like(latents).astype(np.float32) + padded_mask = np.pad( + mask, ((0, 0), (0, 0), (0, 0), (pad_left, pad_right)), mode="constant" + ) + padded_mask += np.multiply(np.ones_like(padded_array), gradient) + + # combine the two copies + result = padded_array + padded_gradiated + flipped_gradiated + result = np.where(padded_mask > 0, result / padded_mask, result) + return result[:, :, :, pad_left : pad_left + latents.shape[3]] + elif direction == "vertical": + pad_top = max(0, -center_line) + pad_bottom = max(0, 2 * center_line - latents.shape[2]) + + # create the symmetrical copies + padded_array = np.pad( + latents, ((0, 0), (0, 0), (pad_top, pad_bottom), (0, 0)), mode="constant" + ) + flipped_array = np.flip(padded_array, axis=2) + + # apply the gradient to both copies + padded_gradiated = np.multiply( + padded_array.transpose(0, 1, 3, 2), gradient + ).transpose(0, 1, 3, 2) + flipped_gradiated = np.multiply( + flipped_array.transpose(0, 1, 3, 2), gradient + ).transpose(0, 1, 3, 2) + + # produce masks + mask = np.ones_like(latents).astype(np.float32) + padded_mask = np.pad( + mask, ((0, 0), (0, 0), (pad_top, pad_bottom), (0, 0)), mode="constant" + ) + padded_mask += np.multiply( + np.ones_like(padded_array).transpose(0, 1, 3, 2), gradient + ).transpose(0, 1, 3, 2) + + # combine the two copies + result = padded_array + padded_gradiated + flipped_gradiated + result = np.where(padded_mask > 0, result / padded_mask, result) + return flipped_array[:, :, pad_top : pad_top + latents.shape[2], :] + else: + raise ValueError("Invalid direction. Must be 'horizontal' or 'vertical'.") diff --git a/api/onnx_web/diffusers/pipelines/base.py b/api/onnx_web/diffusers/pipelines/base.py new file mode 100644 index 00000000..b4275637 --- /dev/null +++ b/api/onnx_web/diffusers/pipelines/base.py @@ -0,0 +1,268 @@ +from typing import List, Optional, Union + +import numpy as np +from diffusers.configuration_utils import FrozenDict +from diffusers.pipelines.onnx_utils import OnnxRuntimeModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate, logging +from transformers import CLIPImageProcessor, CLIPTokenizer + +logger = logging.get_logger(__name__) + + +class OnnxStableDiffusionBasePipeline(DiffusionPipeline): + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="max_length", return_tensors="np" + ).input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids.astype(np.int32) + )[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder( + input_ids=uncond_input.input_ids.astype(np.int32) + )[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat( + negative_prompt_embeds, num_images_per_prompt, axis=0 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) diff --git a/api/tests/test_diffusers/test_scheduler.py b/api/tests/test_diffusers/test_scheduler.py new file mode 100644 index 00000000..30e6d34b --- /dev/null +++ b/api/tests/test_diffusers/test_scheduler.py @@ -0,0 +1,58 @@ +import unittest + +import numpy as np +import torch +from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput + +from onnx_web.diffusers.patches.scheduler import SchedulerPatch, mirror_latents + + +class SchedulerPatchTests(unittest.TestCase): + def test_scheduler_step(self): + scheduler = SchedulerPatch(None) + model_output = torch.FloatTensor([1.0, 2.0, 3.0]) + timestep = torch.Tensor([0.1]) + sample = torch.FloatTensor([0.5, 0.6, 0.7]) + output = scheduler.step(model_output, timestep, sample) + assert isinstance(output, DDIMSchedulerOutput) + + def test_mirror_latents_horizontal(self): + latents = np.array( + [ # batch + [ # channels + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], + ], + ] + ) + black_point = 0 + white_point = 1 + center_line = 2 + direction = "horizontal" + mirrored_latents = mirror_latents( + latents, black_point, white_point, center_line, direction + ) + assert np.array_equal(mirrored_latents, latents) + + def test_mirror_latents_vertical(self): + latents = np.array( + [ # batch + [ # channels + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], + ], + ] + ) + black_point = 0 + white_point = 1 + center_line = 3 + direction = "vertical" + mirrored_latents = mirror_latents( + latents, black_point, white_point, center_line, direction + ) + assert np.array_equal( + mirrored_latents, + [ + [ + [[0, 0, 0], [0, 0, 0], [10, 11, 12], [7, 8, 9]], + ] + ], + )