1
0
Fork 0

fix(models): update attn parameter name in CNet

This commit is contained in:
Sean Sube 2023-08-20 15:00:08 -05:00
parent 0fafc71ec8
commit f73dce507c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 4 additions and 4 deletions

View File

@ -287,7 +287,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
@ -307,7 +307,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
@ -321,7 +321,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@ -367,7 +367,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],