update forwarding shapes, apply lint
This commit is contained in:
parent
b4262eb777
commit
4a87fb2a31
|
@ -190,7 +190,6 @@ def convert_diffusion_diffusers_cnet(
|
|||
),
|
||||
False,
|
||||
),
|
||||
output_path=cnet_path,
|
||||
ordered_input_names=[
|
||||
"sample",
|
||||
"timestep",
|
||||
|
@ -214,23 +213,81 @@ def convert_diffusion_diffusers_cnet(
|
|||
"out_sample"
|
||||
], # has to be different from "sample" for correct tracing
|
||||
dynamic_axes={
|
||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"timestep": {0: "batch"},
|
||||
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
||||
"down_block_0": {0: "batch", 2: "height", 3: "width"},
|
||||
"down_block_1": {0: "batch", 2: "height", 3: "width"},
|
||||
"down_block_2": {0: "batch", 2: "height", 3: "width"},
|
||||
"down_block_3": {0: "batch", 2: "height2", 3: "width2"},
|
||||
"down_block_4": {0: "batch", 2: "height2", 3: "width2"},
|
||||
"down_block_5": {0: "batch", 2: "height2", 3: "width2"},
|
||||
"down_block_6": {0: "batch", 2: "height4", 3: "width4"},
|
||||
"down_block_7": {0: "batch", 2: "height4", 3: "width4"},
|
||||
"down_block_8": {0: "batch", 2: "height4", 3: "width4"},
|
||||
"down_block_9": {0: "batch", 2: "height8", 3: "width8"},
|
||||
"down_block_10": {0: "batch", 2: "height8", 3: "width8"},
|
||||
"down_block_11": {0: "batch", 2: "height8", 3: "width8"},
|
||||
"mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
|
||||
"sample": {
|
||||
0: "cnet_sample_batch",
|
||||
1: "cnet_sample_channels",
|
||||
2: "cnet_sample_height",
|
||||
3: "cnet_sample_width",
|
||||
},
|
||||
"timestep": {0: "cnet_timestep_batch"},
|
||||
"encoder_hidden_states": {0: "cnet_ehs_batch", 1: "cnet_ehs_sequence"},
|
||||
"down_block_0": {
|
||||
0: "cnet_db0_batch",
|
||||
2: "cnet_db0_height",
|
||||
3: "cnet_db0_width",
|
||||
},
|
||||
"down_block_1": {
|
||||
0: "cnet_db1_batch",
|
||||
2: "cnet_db1_height",
|
||||
3: "cnet_db1_width",
|
||||
},
|
||||
"down_block_2": {
|
||||
0: "cnet_db2_batch",
|
||||
2: "cnet_db2_height",
|
||||
3: "cnet_db2_width",
|
||||
},
|
||||
"down_block_3": {
|
||||
0: "cnet_db3_batch",
|
||||
2: "cnet_db3_height2",
|
||||
3: "cnet_db3_width2",
|
||||
},
|
||||
"down_block_4": {
|
||||
0: "cnet_db4_batch",
|
||||
2: "cnet_db4_height2",
|
||||
3: "cnet_db4_width2",
|
||||
},
|
||||
"down_block_5": {
|
||||
0: "cnet_db5_batch",
|
||||
2: "cnet_db5_height2",
|
||||
3: "cnet_db5_width2",
|
||||
},
|
||||
"down_block_6": {
|
||||
0: "cnet_db6_batch",
|
||||
2: "cnet_db6_height4",
|
||||
3: "cnet_db6_width4",
|
||||
},
|
||||
"down_block_7": {
|
||||
0: "cnet_db7_batch",
|
||||
2: "cnet_db7_height4",
|
||||
3: "cnet_db7_width4",
|
||||
},
|
||||
"down_block_8": {
|
||||
0: "cnet_db8_batch",
|
||||
2: "cnet_db8_height4",
|
||||
3: "cnet_db8_width4",
|
||||
},
|
||||
"down_block_9": {
|
||||
0: "cnet_db9_batch",
|
||||
2: "cnet_db9_height8",
|
||||
3: "cnet_db9_width8",
|
||||
},
|
||||
"down_block_10": {
|
||||
0: "cnet_db10_batch",
|
||||
2: "cnet_db10_height8",
|
||||
3: "cnet_db10_width8",
|
||||
},
|
||||
"down_block_11": {
|
||||
0: "cnet_db11_batch",
|
||||
2: "cnet_db11_height8",
|
||||
3: "cnet_db11_width8",
|
||||
},
|
||||
"mid_block_additional_residual": {
|
||||
0: "cnet_mbar_batch",
|
||||
2: "cnet_mbar_height8",
|
||||
3: "cnet_mbar_width8",
|
||||
},
|
||||
},
|
||||
output_path=cnet_path,
|
||||
opset=conversion.opset,
|
||||
half=conversion.half,
|
||||
external_data=True, # UNet is > 2GB, so the weights need to be split
|
||||
|
|
|
@ -1,21 +1,3 @@
|
|||
# This file contains a mix of Apache and GPL code and should be treated as a GPL resource
|
||||
#
|
||||
# Original attribution:
|
||||
#
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from logging import getLogger
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
@ -27,7 +9,7 @@ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class UNet2DConditionModel_Cnet(UNet2DConditionModel):
|
||||
class UNet2DConditionModel_CNet(UNet2DConditionModel):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
@ -49,15 +31,24 @@ class UNet2DConditionModel_Cnet(UNet2DConditionModel):
|
|||
return_dict: bool = False,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
down_block_add_res = (
|
||||
down_block_add_res00, down_block_add_res01, down_block_add_res02,
|
||||
down_block_add_res03, down_block_add_res04, down_block_add_res05,
|
||||
down_block_add_res06, down_block_add_res07, down_block_add_res08,
|
||||
down_block_add_res09, down_block_add_res10, down_block_add_res11)
|
||||
return super().forward(
|
||||
sample = sample,
|
||||
timestep = timestep,
|
||||
encoder_hidden_states = encoder_hidden_states,
|
||||
down_block_additional_residuals = down_block_add_res,
|
||||
mid_block_additional_residual = mid_block_additional_residual,
|
||||
return_dict = return_dict
|
||||
down_block_add_res00,
|
||||
down_block_add_res01,
|
||||
down_block_add_res02,
|
||||
down_block_add_res03,
|
||||
down_block_add_res04,
|
||||
down_block_add_res05,
|
||||
down_block_add_res06,
|
||||
down_block_add_res07,
|
||||
down_block_add_res08,
|
||||
down_block_add_res09,
|
||||
down_block_add_res10,
|
||||
down_block_add_res11,
|
||||
)
|
||||
return super().forward(
|
||||
sample=sample,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
down_block_additional_residuals=down_block_add_res,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue