From 46d1b5636d48f49f3a8bec80618d4e5ed0f9df8b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Mar 2023 08:40:06 -0500 Subject: [PATCH] normalize hidden states without using CLIP model class --- api/onnx_web/diffusers/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 45d9157e..0b8c6d47 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -3,9 +3,9 @@ from math import ceil from re import Pattern, compile from typing import List, Optional, Tuple +import torch import numpy as np from diffusers import OnnxStableDiffusionPipeline -from transformers import CLIPTextModel logger = getLogger(__name__) @@ -67,7 +67,6 @@ def expand_prompt( groups.append(tokens.input_ids[:, group_start:group_end]) # encode each chunk - torch_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder") logger.trace("group token shapes: %s", [t.shape for t in groups]) group_embeds = [] for group in groups: @@ -77,9 +76,9 @@ def expand_prompt( logger.info("text encoder result: %s", text_result) last_state, _pooled_output, *hidden_states = text_result - if skip_clip_states > 1: - last_state = hidden_states[-skip_clip_states] - norm_state = torch_encoder.text_model.final_layer_norm(torch.from_numpy(last_state).detach()) + if skip_clip_states > 0: + layer_norm = torch.nn.LayerNorm(last_state.shape[2]) + norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach()) logger.info("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape) group_embeds.append(norm_state) else: