fix(models): update attn parameter name in CNet
This commit is contained in:
parent
0fafc71ec8
commit
f73dce507c
|
@ -287,7 +287,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
|
|||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[i],
|
||||
num_attention_heads=attention_head_dim[i],
|
||||
downsample_padding=downsample_padding,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
|
@ -307,7 +307,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
|
|||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
num_attention_heads=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
|
@ -321,7 +321,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
|
|||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
num_attention_heads=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
|
@ -367,7 +367,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
|
|||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=reversed_attention_head_dim[i],
|
||||
num_attention_heads=reversed_attention_head_dim[i],
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
|
|
Loading…
Reference in New Issue