diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py
index 64b67830..d1d3e591 100644
--- a/api/onnx_web/convert/diffusion/diffusers.py
+++ b/api/onnx_web/convert/diffusion/diffusers.py
@@ -136,11 +136,17 @@ def convert_diffusion_diffusers(
         pipeline.text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
         model_args=(
-            text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)
+            text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
+            None,  # attention mask
+            None,  # position ids
+            None,  # output attentions
+            torch.tensor(True).to(
+                device=ctx.training_device, dtype=torch.bool
+            ),
         ),
         output_path=output_path / "text_encoder" / "model.onnx",
         ordered_input_names=["input_ids"],
-        output_names=["last_hidden_state", "pooler_output"],
+        output_names=["last_hidden_state", "pooler_output", "hidden_states"],
         dynamic_axes={
             "input_ids": {0: "batch", 1: "sequence"},
         },
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index a1360138..45d9157e 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 
 import numpy as np
 from diffusers import OnnxStableDiffusionPipeline
+from transformers import CLIPTextModel
 
 logger = getLogger(__name__)
 
@@ -26,17 +27,20 @@ def expand_prompt_ranges(prompt: str) -> str:
     return PATTERN_RANGE.sub(expand_range, prompt)
 
 
+@torch.no_grad()
 def expand_prompt(
     self: OnnxStableDiffusionPipeline,
     prompt: str,
     num_images_per_prompt: int,
     do_classifier_free_guidance: bool,
     negative_prompt: Optional[str] = None,
+    skip_clip_states: Optional[int] = 0,
 ) -> "np.NDArray":
     # self provides:
     #   tokenizer: CLIPTokenizer
     #   encoder: OnnxRuntimeModel
+
     batch_size = len(prompt) if isinstance(prompt, list) else 1
     prompt = expand_prompt_ranges(prompt)
 
@@ -63,12 +67,23 @@ def expand_prompt(
         groups.append(tokens.input_ids[:, group_start:group_end])
 
     # encode each chunk
+    torch_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
     logger.trace("group token shapes: %s", [t.shape for t in groups])
     group_embeds = []
     for group in groups:
         logger.trace("encoding group: %s", group.shape)
-        embeds = self.text_encoder(input_ids=group.astype(np.int32))[0]
-        group_embeds.append(embeds)
+
+        text_result = self.text_encoder(input_ids=group.astype(np.int32))
+        logger.info("text encoder result: %s", text_result)
+
+        last_state, _pooled_output, *hidden_states = text_result
+        if skip_clip_states > 1:
+            last_state = hidden_states[-skip_clip_states]
+            norm_state = torch_encoder.text_model.final_layer_norm(torch.from_numpy(last_state).detach())
+            logger.info("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
+            group_embeds.append(norm_state)
+        else:
+            group_embeds.append(last_state)
 
     # concat those embeds
     logger.trace("group embeds shape: %s", [t.shape for t in group_embeds])
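
A minimal sketch of the clip-skip selection that the expand_prompt changes implement, using dummy arrays in place of the ONNX text encoder outputs. NUM_LAYERS, HIDDEN_SIZE, SEQ_LEN, the final_layer_norm stand-in, and the select_embeds helper are illustrative assumptions, not objects from the patched pipeline:

    import numpy as np
    import torch

    NUM_LAYERS = 12      # depth of the CLIP ViT-L/14 text encoder (assumed)
    HIDDEN_SIZE = 768
    SEQ_LEN = 77

    # Stand-ins for the exported encoder outputs now exposed via output_names:
    # last_hidden_state, pooler_output, then one hidden_states entry per layer.
    last_state = np.zeros((1, SEQ_LEN, HIDDEN_SIZE), dtype=np.float32)
    hidden_states = [
        np.zeros((1, SEQ_LEN, HIDDEN_SIZE), dtype=np.float32)
        for _ in range(NUM_LAYERS)
    ]

    # Stand-in for torch_encoder.text_model.final_layer_norm.
    final_layer_norm = torch.nn.LayerNorm(HIDDEN_SIZE)

    def select_embeds(skip_clip_states: int) -> torch.Tensor:
        # skip_clip_states > 1 selects an earlier hidden state and re-applies
        # the final layer norm, mirroring the branch added in expand_prompt.
        if skip_clip_states > 1:
            state = torch.from_numpy(hidden_states[-skip_clip_states])
            return final_layer_norm(state)
        return torch.from_numpy(last_state)

    print(select_embeds(2).shape)  # torch.Size([1, 77, 768])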