apply lint
This commit is contained in:
parent
641456fd42
commit
243a2d9df6
|
@ -137,12 +137,10 @@ def convert_diffusion_diffusers(
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
model_args=(
|
model_args=(
|
||||||
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
|
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
|
||||||
None, # attention mask
|
None, # attention mask
|
||||||
None, # position ids
|
None, # position ids
|
||||||
None, # output attentions
|
None, # output attentions
|
||||||
torch.tensor(True).to(
|
torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
|
||||||
device=ctx.training_device, dtype=torch.bool
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
output_path=output_path / "text_encoder" / "model.onnx",
|
output_path=output_path / "text_encoder" / "model.onnx",
|
||||||
ordered_input_names=["input_ids"],
|
ordered_input_names=["input_ids"],
|
||||||
|
|
|
@ -3,8 +3,8 @@ from math import ceil
|
||||||
from re import Pattern, compile
|
from re import Pattern, compile
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -78,13 +78,21 @@ def expand_prompt(
|
||||||
logger.trace("encoding group: %s", group.shape)
|
logger.trace("encoding group: %s", group.shape)
|
||||||
|
|
||||||
text_result = self.text_encoder(input_ids=group.astype(np.int32))
|
text_result = self.text_encoder(input_ids=group.astype(np.int32))
|
||||||
logger.trace("text encoder produced %s outputs: %s", len(text_result), text_result)
|
logger.trace(
|
||||||
|
"text encoder produced %s outputs: %s", len(text_result), text_result
|
||||||
|
)
|
||||||
|
|
||||||
last_state, _pooled_output, *hidden_states = text_result
|
last_state, _pooled_output, *hidden_states = text_result
|
||||||
if skip_clip_states > 0:
|
if skip_clip_states > 0:
|
||||||
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
||||||
norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
|
norm_state = layer_norm(
|
||||||
logger.trace("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
|
torch.from_numpy(hidden_states[-skip_clip_states]).detach()
|
||||||
|
)
|
||||||
|
logger.trace(
|
||||||
|
"normalized results after skipping %s layers: %s",
|
||||||
|
skip_clip_states,
|
||||||
|
norm_state.shape,
|
||||||
|
)
|
||||||
group_embeds.append(norm_state)
|
group_embeds.append(norm_state)
|
||||||
else:
|
else:
|
||||||
group_embeds.append(last_state)
|
group_embeds.append(last_state)
|
||||||
|
|
Loading…
Reference in New Issue