apply lint
parent 641456fd42
commit 243a2d9df6
@@ -137,12 +137,10 @@ def convert_diffusion_diffusers(
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
         model_args=(
             text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
-            None,  # attention mask
-            None,  # position ids
-            None,  # output attentions
-            torch.tensor(True).to(
-                device=ctx.training_device, dtype=torch.bool
-            ),
+            None,  # attention mask
+            None,  # position ids
+            None,  # output attentions
+            torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
         ),
         output_path=output_path / "text_encoder" / "model.onnx",
         ordered_input_names=["input_ids"],
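For context, the model_args tuple above is the set of positional inputs used to trace the CLIP text encoder during ONNX export, and ordered_input_names labels the inputs kept in the exported graph; the int32 cast works around the linked transformers issue. Below is a minimal sketch of the kind of helper these keyword arguments could feed into, assuming it wraps torch.onnx.export; the name onnx_export, its parameter list, and the opset default are illustrative assumptions, not the project's actual helper.

from pathlib import Path
from typing import List, Tuple

import torch


def onnx_export(
    model: torch.nn.Module,
    model_args: Tuple,
    output_path: Path,
    ordered_input_names: List[str],
    opset: int = 14,
) -> None:
    # trace the model with the positional args from model_args; the Nones in
    # the tuple stand in for optional CLIP inputs (attention mask, position
    # ids, output attentions) that the exported graph does not need
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.onnx.export(
        model,
        model_args,
        f=output_path.as_posix(),
        input_names=ordered_input_names,
        opset_version=opset,
    )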
@@ -3,8 +3,8 @@ from math import ceil
 from re import Pattern, compile
 from typing import List, Optional, Tuple
 
-import torch
 import numpy as np
+import torch
 from diffusers import OnnxStableDiffusionPipeline
 
 logger = getLogger(__name__)
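The only change in this hunk is import ordering: within the third-party group, numpy sorts before torch. This matches isort-style alphabetical ordering (an assumption about which linter is in use); the conventional grouping would look like:

# standard library imports first
from logging import getLogger
from math import ceil

# third-party imports, alphabetical by module name
import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline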
@@ -78,13 +78,21 @@ def expand_prompt(
         logger.trace("encoding group: %s", group.shape)
 
         text_result = self.text_encoder(input_ids=group.astype(np.int32))
-        logger.trace("text encoder produced %s outputs: %s", len(text_result), text_result)
+        logger.trace(
+            "text encoder produced %s outputs: %s", len(text_result), text_result
+        )
 
         last_state, _pooled_output, *hidden_states = text_result
         if skip_clip_states > 0:
             layer_norm = torch.nn.LayerNorm(last_state.shape[2])
-            norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
-            logger.trace("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
+            norm_state = layer_norm(
+                torch.from_numpy(hidden_states[-skip_clip_states]).detach()
+            )
+            logger.trace(
+                "normalized results after skipping %s layers: %s",
+                skip_clip_states,
+                norm_state.shape,
+            )
             group_embeds.append(norm_state)
         else:
             group_embeds.append(last_state)
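The block being reformatted here implements a "clip skip" style selection: when skip_clip_states is greater than zero, the hidden state from skip_clip_states layers before the end of the CLIP text encoder is used in place of the final state, re-normalized with a fresh LayerNorm. A standalone sketch of that selection under assumed shapes (the helper name and the dummy 1x77x768 data are illustrative only):

import numpy as np
import torch


def select_clip_state(
    last_state: np.ndarray,
    hidden_states: list,
    skip_clip_states: int,
) -> torch.Tensor:
    if skip_clip_states > 0:
        # re-normalize the earlier layer's output so its scale is comparable
        # to the final, already-normalized hidden state
        layer_norm = torch.nn.LayerNorm(last_state.shape[2])
        return layer_norm(
            torch.from_numpy(hidden_states[-skip_clip_states]).detach()
        )
    return torch.from_numpy(last_state)


# usage with fake data: batch of 1, 77 tokens, 768-dim embeddings, 13 states
states = [np.random.randn(1, 77, 768).astype(np.float32) for _ in range(13)]
embeds = select_clip_state(states[-1], states[:-1], skip_clip_states=2)
print(embeds.shape)  # torch.Size([1, 77, 768])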