experimental CLIP skip
This commit is contained in:
parent
3f9789a0a8
commit
2ef00599b6
|
@ -136,11 +136,17 @@ def convert_diffusion_diffusers(
|
||||||
pipeline.text_encoder,
|
pipeline.text_encoder,
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
model_args=(
|
model_args=(
|
||||||
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)
|
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
|
||||||
|
None, # attention mask
|
||||||
|
None, # position ids
|
||||||
|
None, # output attentions
|
||||||
|
torch.tensor(True).to(
|
||||||
|
device=ctx.training_device, dtype=torch.bool
|
||||||
|
),
|
||||||
),
|
),
|
||||||
output_path=output_path / "text_encoder" / "model.onnx",
|
output_path=output_path / "text_encoder" / "model.onnx",
|
||||||
ordered_input_names=["input_ids"],
|
ordered_input_names=["input_ids"],
|
||||||
output_names=["last_hidden_state", "pooler_output"],
|
output_names=["last_hidden_state", "pooler_output", "hidden_states"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"input_ids": {0: "batch", 1: "sequence"},
|
"input_ids": {0: "batch", 1: "sequence"},
|
||||||
},
|
},
|
||||||
|
|
|
@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
|
from transformers import CLIPTextModel
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -26,17 +27,20 @@ def expand_prompt_ranges(prompt: str) -> str:
|
||||||
return PATTERN_RANGE.sub(expand_range, prompt)
|
return PATTERN_RANGE.sub(expand_range, prompt)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def expand_prompt(
|
def expand_prompt(
|
||||||
self: OnnxStableDiffusionPipeline,
|
self: OnnxStableDiffusionPipeline,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
num_images_per_prompt: int,
|
num_images_per_prompt: int,
|
||||||
do_classifier_free_guidance: bool,
|
do_classifier_free_guidance: bool,
|
||||||
negative_prompt: Optional[str] = None,
|
negative_prompt: Optional[str] = None,
|
||||||
|
skip_clip_states: Optional[str] = 0,
|
||||||
) -> "np.NDArray":
|
) -> "np.NDArray":
|
||||||
# self provides:
|
# self provides:
|
||||||
# tokenizer: CLIPTokenizer
|
# tokenizer: CLIPTokenizer
|
||||||
# encoder: OnnxRuntimeModel
|
# encoder: OnnxRuntimeModel
|
||||||
|
|
||||||
|
|
||||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||||
prompt = expand_prompt_ranges(prompt)
|
prompt = expand_prompt_ranges(prompt)
|
||||||
|
|
||||||
|
@ -63,12 +67,23 @@ def expand_prompt(
|
||||||
groups.append(tokens.input_ids[:, group_start:group_end])
|
groups.append(tokens.input_ids[:, group_start:group_end])
|
||||||
|
|
||||||
# encode each chunk
|
# encode each chunk
|
||||||
|
torch_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
|
||||||
logger.trace("group token shapes: %s", [t.shape for t in groups])
|
logger.trace("group token shapes: %s", [t.shape for t in groups])
|
||||||
group_embeds = []
|
group_embeds = []
|
||||||
for group in groups:
|
for group in groups:
|
||||||
logger.trace("encoding group: %s", group.shape)
|
logger.trace("encoding group: %s", group.shape)
|
||||||
embeds = self.text_encoder(input_ids=group.astype(np.int32))[0]
|
|
||||||
group_embeds.append(embeds)
|
text_result = self.text_encoder(input_ids=group.astype(np.int32))
|
||||||
|
logger.info("text encoder result: %s", text_result)
|
||||||
|
|
||||||
|
last_state, _pooled_output, *hidden_states = text_result
|
||||||
|
if skip_clip_states > 1:
|
||||||
|
last_state = hidden_states[-skip_clip_states]
|
||||||
|
norm_state = torch_encoder.text_model.final_layer_norm(torch.from_numpy(last_state).detach())
|
||||||
|
logger.info("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
|
||||||
|
group_embeds.append(norm_state)
|
||||||
|
else:
|
||||||
|
group_embeds.append(last_state)
|
||||||
|
|
||||||
# concat those embeds
|
# concat those embeds
|
||||||
logger.trace("group embeds shape: %s", [t.shape for t in group_embeds])
|
logger.trace("group embeds shape: %s", [t.shape for t in group_embeds])
|
||||||
|
|
Loading…
Reference in New Issue