apply lint

commit 243a2d9df6
parent 641456fd42
Sean Sube, 2023-03-19 09:29:06 -05:00
Signed by: ssube (GPG Key ID: 3EED7B957D362AF1)
2 changed files with 16 additions and 10 deletions


@@ -137,12 +137,10 @@ def convert_diffusion_diffusers(
     # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
     model_args=(
         text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
-        None, # attention mask
-        None, # position ids
-        None, # output attentions
-        torch.tensor(True).to(
-            device=ctx.training_device, dtype=torch.bool
-        ),
+        None,  # attention mask
+        None,  # position ids
+        None,  # output attentions
+        torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
     ),
     output_path=output_path / "text_encoder" / "model.onnx",
     ordered_input_names=["input_ids"],
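For context, this model_args tuple supplies the positional inputs of CLIPTextModel.forward (input_ids, attention_mask, position_ids, output_attentions, output_hidden_states), so the trailing bool tensor turns on hidden-state output. Below is a minimal, self-contained sketch of the kind of torch.onnx.export call that consumes such a tuple; the model name, opset version, and output file name are assumptions for illustration, not values from this commit:

# Sketch only: exporting a CLIP text encoder via torch.onnx.export with a
# model_args tuple shaped like the one above. Model name, opset version,
# and output path are assumptions, not taken from this commit.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

device = "cpu"  # stand-in for ctx.training_device
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

text_input = tokenizer(
    "an example prompt",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)

torch.onnx.export(
    encoder,
    (
        # positional order follows CLIPTextModel.forward
        text_input.input_ids.to(device=device, dtype=torch.int32),
        None,  # attention_mask
        None,  # position_ids
        None,  # output_attentions
        torch.tensor(True).to(device=device, dtype=torch.bool),  # output_hidden_states
    ),
    "text_encoder.onnx",
    input_names=["input_ids"],
    opset_version=14,  # assumption
)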


@@ -3,8 +3,8 @@ from math import ceil
 from re import Pattern, compile
 from typing import List, Optional, Tuple

-import torch
 import numpy as np
+import torch
 from diffusers import OnnxStableDiffusionPipeline

 logger = getLogger(__name__)
@@ -78,13 +78,21 @@ def expand_prompt(
         logger.trace("encoding group: %s", group.shape)

         text_result = self.text_encoder(input_ids=group.astype(np.int32))
-        logger.trace("text encoder produced %s outputs: %s", len(text_result), text_result)
+        logger.trace(
+            "text encoder produced %s outputs: %s", len(text_result), text_result
+        )

         last_state, _pooled_output, *hidden_states = text_result
         if skip_clip_states > 0:
             layer_norm = torch.nn.LayerNorm(last_state.shape[2])
-            norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
-            logger.trace("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
+            norm_state = layer_norm(
+                torch.from_numpy(hidden_states[-skip_clip_states]).detach()
+            )
+            logger.trace(
+                "normalized results after skipping %s layers: %s",
+                skip_clip_states,
+                norm_state.shape,
+            )
             group_embeds.append(norm_state)
         else:
             group_embeds.append(last_state)
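
The skip_clip_states branch is a form of "CLIP skip": instead of the encoder's final output, it takes a hidden state from an earlier layer and re-normalizes it with a fresh LayerNorm over the embedding dimension. A standalone sketch of that idea follows; the shapes (batch 1, 77 tokens, 768-dim embeddings) and layer count are illustrative assumptions, not values from this code:

# Sketch of the CLIP-skip technique above: pick a hidden state a few layers
# before the end and apply a fresh LayerNorm over the embedding dimension.
# Shapes and layer count here are illustrative assumptions.
import numpy as np
import torch

hidden_states = [np.random.randn(1, 77, 768).astype(np.float32) for _ in range(12)]
last_state = hidden_states[-1]
skip_clip_states = 2  # assumption: skip the last two encoder layers

if skip_clip_states > 0:
    layer_norm = torch.nn.LayerNorm(last_state.shape[2])  # normalize over 768 dims
    norm_state = layer_norm(
        torch.from_numpy(hidden_states[-skip_clip_states]).detach()
    )
    embeds = norm_state
else:
    embeds = torch.from_numpy(last_state)

print(embeds.shape)  # torch.Size([1, 77, 768])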