
normalize hidden states without using CLIP model class

Sean Sube 2023-03-19 08:40:06 -05:00
parent 2ef00599b6
commit 46d1b5636d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 4 additions and 5 deletions


@@ -3,9 +3,9 @@ from math import ceil
 from re import Pattern, compile
 from typing import List, Optional, Tuple
 
+import torch
 import numpy as np
 from diffusers import OnnxStableDiffusionPipeline
-from transformers import CLIPTextModel
 
 logger = getLogger(__name__)
@@ -67,7 +67,6 @@ def expand_prompt(
         groups.append(tokens.input_ids[:, group_start:group_end])
 
     # encode each chunk
-    torch_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
     logger.trace("group token shapes: %s", [t.shape for t in groups])
     group_embeds = []
     for group in groups:
@@ -77,9 +76,9 @@ def expand_prompt(
         logger.info("text encoder result: %s", text_result)
         last_state, _pooled_output, *hidden_states = text_result
 
-        if skip_clip_states > 1:
-            last_state = hidden_states[-skip_clip_states]
-            norm_state = torch_encoder.text_model.final_layer_norm(torch.from_numpy(last_state).detach())
+        if skip_clip_states > 0:
+            layer_norm = torch.nn.LayerNorm(last_state.shape[2])
+            norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
             logger.info("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
             group_embeds.append(norm_state)
         else:
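
The new code drops the workaround removed above: instead of loading the full PyTorch CLIPTextModel just to borrow its trained final_layer_norm, the skipped hidden state is normalized with a freshly constructed torch.nn.LayerNorm. A minimal sketch of that path in isolation, assuming the ONNX text encoder returns numpy arrays as (last_state, pooled_output, *hidden_states); the normalize_hidden_state helper and the fake outputs below are illustrative, not part of the commit:

import numpy as np
import torch


def normalize_hidden_state(text_result, skip_clip_states: int) -> torch.Tensor:
    # the encoder output tuple: (last_state, pooled_output, *hidden_states)
    last_state, _pooled_output, *hidden_states = text_result

    if skip_clip_states > 0:
        # normalize over the embedding dimension (last_state.shape[2]);
        # weight and bias start at 1 and 0, so no trained affine is applied
        layer_norm = torch.nn.LayerNorm(last_state.shape[2])
        return layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())

    return torch.from_numpy(last_state)


# fake encoder outputs: batch 1, 77 tokens, 768-dim embeddings, 13 hidden
# states (pooled output shape simplified for the example)
outputs = [np.random.rand(1, 77, 768).astype(np.float32) for _ in range(15)]
embeds = normalize_hidden_state(outputs, skip_clip_states=2)
print(embeds.shape)  # torch.Size([1, 77, 768])

One consequence of this design choice: a fresh torch.nn.LayerNorm initializes its affine weight to ones and bias to zeros, so it applies plain normalization rather than CLIP's trained scale and shift, and can therefore produce slightly different embeddings than the previous torch_encoder path.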