From f73dce507ced912e14c9f0d757fbb65b02d17989 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 20 Aug 2023 15:00:08 -0500
Subject: [PATCH] fix(models): update attn parameter name in CNet

---
 api/onnx_web/models/cnet.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/api/onnx_web/models/cnet.py b/api/onnx_web/models/cnet.py
index 2c969a1c..1c9a9a02 100644
--- a/api/onnx_web/models/cnet.py
+++ b/api/onnx_web/models/cnet.py
@@ -287,7 +287,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim[i],
+                num_attention_heads=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -307,7 +307,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
                 output_scale_factor=mid_block_scale_factor,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim[-1],
+                num_attention_heads=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -321,7 +321,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim[-1],
+                num_attention_heads=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
             )
@@ -367,7 +367,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=reversed_attention_head_dim[i],
+                num_attention_heads=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
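
Note (not part of the patch): each hunk above tracks the diffusers rename of the per-block attention keyword from attn_num_head_channels to num_attention_heads. A minimal sketch of a version-tolerant way to pass that keyword, assuming get_down_block is importable from diffusers.models.unet_2d_blocks; the helper name attention_heads_kwarg is hypothetical and not from the patched file.

    # Compatibility sketch, not part of the patch.
    import inspect

    from diffusers.models.unet_2d_blocks import get_down_block


    def attention_heads_kwarg(num_heads: int) -> dict:
        """Return the attention-head argument under whichever keyword the
        installed diffusers version accepts (older releases take
        attn_num_head_channels, newer ones num_attention_heads)."""
        params = inspect.signature(get_down_block).parameters
        if "num_attention_heads" in params:
            return {"num_attention_heads": num_heads}
        return {"attn_num_head_channels": num_heads}


    # Usage inside the down-block loop (names taken from the patch):
    # down_block = get_down_block(
    #     down_block_type,
    #     cross_attention_dim=cross_attention_dim,
    #     **attention_heads_kwarg(attention_head_dim[i]),
    #     ...
    # )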