fix(models): update attn parameter name in CNet
This commit is contained in:
parent
0fafc71ec8
commit
f73dce507c
|
@ -287,7 +287,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
|
||||||
resnet_act_fn=act_fn,
|
resnet_act_fn=act_fn,
|
||||||
resnet_groups=norm_num_groups,
|
resnet_groups=norm_num_groups,
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
attn_num_head_channels=attention_head_dim[i],
|
num_attention_heads=attention_head_dim[i],
|
||||||
downsample_padding=downsample_padding,
|
downsample_padding=downsample_padding,
|
||||||
dual_cross_attention=dual_cross_attention,
|
dual_cross_attention=dual_cross_attention,
|
||||||
use_linear_projection=use_linear_projection,
|
use_linear_projection=use_linear_projection,
|
||||||
|
@ -307,7 +307,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
|
||||||
output_scale_factor=mid_block_scale_factor,
|
output_scale_factor=mid_block_scale_factor,
|
||||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
attn_num_head_channels=attention_head_dim[-1],
|
num_attention_heads=attention_head_dim[-1],
|
||||||
resnet_groups=norm_num_groups,
|
resnet_groups=norm_num_groups,
|
||||||
dual_cross_attention=dual_cross_attention,
|
dual_cross_attention=dual_cross_attention,
|
||||||
use_linear_projection=use_linear_projection,
|
use_linear_projection=use_linear_projection,
|
||||||
|
@ -321,7 +321,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
|
||||||
resnet_act_fn=act_fn,
|
resnet_act_fn=act_fn,
|
||||||
output_scale_factor=mid_block_scale_factor,
|
output_scale_factor=mid_block_scale_factor,
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
attn_num_head_channels=attention_head_dim[-1],
|
num_attention_heads=attention_head_dim[-1],
|
||||||
resnet_groups=norm_num_groups,
|
resnet_groups=norm_num_groups,
|
||||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
)
|
)
|
||||||
|
@ -367,7 +367,7 @@ class UNet2DConditionModel_CNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersM
|
||||||
resnet_act_fn=act_fn,
|
resnet_act_fn=act_fn,
|
||||||
resnet_groups=norm_num_groups,
|
resnet_groups=norm_num_groups,
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
attn_num_head_channels=reversed_attention_head_dim[i],
|
num_attention_heads=reversed_attention_head_dim[i],
|
||||||
dual_cross_attention=dual_cross_attention,
|
dual_cross_attention=dual_cross_attention,
|
||||||
use_linear_projection=use_linear_projection,
|
use_linear_projection=use_linear_projection,
|
||||||
only_cross_attention=only_cross_attention[i],
|
only_cross_attention=only_cross_attention[i],
|
||||||
|
|
Loading…
Reference in New Issue