normalize hidden states without using CLIP model class
This commit is contained in:
parent
2ef00599b6
commit
46d1b5636d
|
@ -3,9 +3,9 @@ from math import ceil
|
||||||
from re import Pattern, compile
|
from re import Pattern, compile
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
from transformers import CLIPTextModel
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -67,7 +67,6 @@ def expand_prompt(
|
||||||
groups.append(tokens.input_ids[:, group_start:group_end])
|
groups.append(tokens.input_ids[:, group_start:group_end])
|
||||||
|
|
||||||
# encode each chunk
|
# encode each chunk
|
||||||
torch_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
|
|
||||||
logger.trace("group token shapes: %s", [t.shape for t in groups])
|
logger.trace("group token shapes: %s", [t.shape for t in groups])
|
||||||
group_embeds = []
|
group_embeds = []
|
||||||
for group in groups:
|
for group in groups:
|
||||||
|
@ -77,9 +76,9 @@ def expand_prompt(
|
||||||
logger.info("text encoder result: %s", text_result)
|
logger.info("text encoder result: %s", text_result)
|
||||||
|
|
||||||
last_state, _pooled_output, *hidden_states = text_result
|
last_state, _pooled_output, *hidden_states = text_result
|
||||||
if skip_clip_states > 1:
|
if skip_clip_states > 0:
|
||||||
last_state = hidden_states[-skip_clip_states]
|
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
||||||
norm_state = torch_encoder.text_model.final_layer_norm(torch.from_numpy(last_state).detach())
|
norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
|
||||||
logger.info("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
|
logger.info("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
|
||||||
group_embeds.append(norm_state)
|
group_embeds.append(norm_state)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue