diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py
index d1d3e591..090335c5 100644
--- a/api/onnx_web/convert/diffusion/diffusers.py
+++ b/api/onnx_web/convert/diffusion/diffusers.py
@@ -137,12 +137,10 @@ def convert_diffusion_diffusers(
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
         model_args=(
             text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32),
-            None, # attention mask
-            None, # position ids
-            None, # output attentions
-            torch.tensor(True).to(
-                device=ctx.training_device, dtype=torch.bool
-            ),
+            None,  # attention mask
+            None,  # position ids
+            None,  # output attentions
+            torch.tensor(True).to(device=ctx.training_device, dtype=torch.bool),
         ),
         output_path=output_path / "text_encoder" / "model.onnx",
         ordered_input_names=["input_ids"],
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index 7c0cb2a5..958f672a 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -3,8 +3,8 @@ from math import ceil
 from re import Pattern, compile
 from typing import List, Optional, Tuple
 
-import torch
 import numpy as np
+import torch
 from diffusers import OnnxStableDiffusionPipeline
 
 logger = getLogger(__name__)
@@ -78,13 +78,21 @@ def expand_prompt(
 
         logger.trace("encoding group: %s", group.shape)
         text_result = self.text_encoder(input_ids=group.astype(np.int32))
-        logger.trace("text encoder produced %s outputs: %s", len(text_result), text_result)
+        logger.trace(
+            "text encoder produced %s outputs: %s", len(text_result), text_result
+        )
 
         last_state, _pooled_output, *hidden_states = text_result
         if skip_clip_states > 0:
             layer_norm = torch.nn.LayerNorm(last_state.shape[2])
-            norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
-            logger.trace("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
+            norm_state = layer_norm(
+                torch.from_numpy(hidden_states[-skip_clip_states]).detach()
+            )
+            logger.trace(
+                "normalized results after skipping %s layers: %s",
+                skip_clip_states,
+                norm_state.shape,
+            )
             group_embeds.append(norm_state)
         else:
             group_embeds.append(last_state)
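
For reference, the `skip_clip_states` branch visible in the second hunk's context implements a CLIP-skip style embedding: instead of the encoder's final hidden state, it takes the hidden state `skip_clip_states` layers from the end and re-normalizes it with a fresh `torch.nn.LayerNorm`. A minimal standalone sketch of that step, assuming `text_result` is the tuple of numpy arrays returned by the ONNX text encoder (last hidden state, pooled output, then per-layer hidden states); the function name is illustrative, not part of the patch:

```python
import numpy as np
import torch


def clip_skip_embeds(text_result, skip_clip_states: int) -> np.ndarray:
    # unpack the ONNX text encoder outputs: last hidden state, pooled
    # output, then one hidden state per transformer layer
    last_state, _pooled_output, *hidden_states = text_result

    if skip_clip_states > 0:
        # normalize the hidden state N layers from the end, mirroring
        # the expand_prompt context lines above
        layer_norm = torch.nn.LayerNorm(last_state.shape[2])
        norm_state = layer_norm(
            torch.from_numpy(hidden_states[-skip_clip_states]).detach()
        )
        return norm_state.detach().numpy()

    return last_state
```

Note that the `LayerNorm` is freshly constructed, so its affine parameters are the defaults (weight 1, bias 0) and it acts as a plain normalization rather than the encoder's trained `final_layer_norm`.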