1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-03-19 09:29:06 -05:00
parent 641456fd42
commit 243a2d9df6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 16 additions and 10 deletions

View File

@ -137,12 +137,10 @@ def convert_diffusion_diffusers(
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=( model_args=(
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32), text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
None, # attention mask None, # attention mask
None, # position ids None, # position ids
None, # output attentions None, # output attentions
torch.tensor(True).to( torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
device=ctx.training_device, dtype=torch.bool
),
), ),
output_path=output_path / "text_encoder" / "model.onnx", output_path=output_path / "text_encoder" / "model.onnx",
ordered_input_names=["input_ids"], ordered_input_names=["input_ids"],

View File

@ -3,8 +3,8 @@ from math import ceil
from re import Pattern, compile from re import Pattern, compile
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch
import numpy as np import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionPipeline
logger = getLogger(__name__) logger = getLogger(__name__)
@ -78,13 +78,21 @@ def expand_prompt(
logger.trace("encoding group: %s", group.shape) logger.trace("encoding group: %s", group.shape)
text_result = self.text_encoder(input_ids=group.astype(np.int32)) text_result = self.text_encoder(input_ids=group.astype(np.int32))
logger.trace("text encoder produced %s outputs: %s", len(text_result), text_result) logger.trace(
"text encoder produced %s outputs: %s", len(text_result), text_result
)
last_state, _pooled_output, *hidden_states = text_result last_state, _pooled_output, *hidden_states = text_result
if skip_clip_states > 0: if skip_clip_states > 0:
layer_norm = torch.nn.LayerNorm(last_state.shape[2]) layer_norm = torch.nn.LayerNorm(last_state.shape[2])
norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach()) norm_state = layer_norm(
logger.trace("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape) torch.from_numpy(hidden_states[-skip_clip_states]).detach()
)
logger.trace(
"normalized results after skipping %s layers: %s",
skip_clip_states,
norm_state.shape,
)
group_embeds.append(norm_state) group_embeds.append(norm_state)
else: else:
group_embeds.append(last_state) group_embeds.append(last_state)