update forwarding shapes, apply lint
This commit is contained in:
parent
b4262eb777
commit
4a87fb2a31
|
@ -190,7 +190,6 @@ def convert_diffusion_diffusers_cnet(
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=cnet_path,
|
|
||||||
ordered_input_names=[
|
ordered_input_names=[
|
||||||
"sample",
|
"sample",
|
||||||
"timestep",
|
"timestep",
|
||||||
|
@ -214,23 +213,81 @@ def convert_diffusion_diffusers_cnet(
|
||||||
"out_sample"
|
"out_sample"
|
||||||
], # has to be different from "sample" for correct tracing
|
], # has to be different from "sample" for correct tracing
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
"sample": {
|
||||||
"timestep": {0: "batch"},
|
0: "cnet_sample_batch",
|
||||||
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
1: "cnet_sample_channels",
|
||||||
"down_block_0": {0: "batch", 2: "height", 3: "width"},
|
2: "cnet_sample_height",
|
||||||
"down_block_1": {0: "batch", 2: "height", 3: "width"},
|
3: "cnet_sample_width",
|
||||||
"down_block_2": {0: "batch", 2: "height", 3: "width"},
|
|
||||||
"down_block_3": {0: "batch", 2: "height2", 3: "width2"},
|
|
||||||
"down_block_4": {0: "batch", 2: "height2", 3: "width2"},
|
|
||||||
"down_block_5": {0: "batch", 2: "height2", 3: "width2"},
|
|
||||||
"down_block_6": {0: "batch", 2: "height4", 3: "width4"},
|
|
||||||
"down_block_7": {0: "batch", 2: "height4", 3: "width4"},
|
|
||||||
"down_block_8": {0: "batch", 2: "height4", 3: "width4"},
|
|
||||||
"down_block_9": {0: "batch", 2: "height8", 3: "width8"},
|
|
||||||
"down_block_10": {0: "batch", 2: "height8", 3: "width8"},
|
|
||||||
"down_block_11": {0: "batch", 2: "height8", 3: "width8"},
|
|
||||||
"mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
|
|
||||||
},
|
},
|
||||||
|
"timestep": {0: "cnet_timestep_batch"},
|
||||||
|
"encoder_hidden_states": {0: "cnet_ehs_batch", 1: "cnet_ehs_sequence"},
|
||||||
|
"down_block_0": {
|
||||||
|
0: "cnet_db0_batch",
|
||||||
|
2: "cnet_db0_height",
|
||||||
|
3: "cnet_db0_width",
|
||||||
|
},
|
||||||
|
"down_block_1": {
|
||||||
|
0: "cnet_db1_batch",
|
||||||
|
2: "cnet_db1_height",
|
||||||
|
3: "cnet_db1_width",
|
||||||
|
},
|
||||||
|
"down_block_2": {
|
||||||
|
0: "cnet_db2_batch",
|
||||||
|
2: "cnet_db2_height",
|
||||||
|
3: "cnet_db2_width",
|
||||||
|
},
|
||||||
|
"down_block_3": {
|
||||||
|
0: "cnet_db3_batch",
|
||||||
|
2: "cnet_db3_height2",
|
||||||
|
3: "cnet_db3_width2",
|
||||||
|
},
|
||||||
|
"down_block_4": {
|
||||||
|
0: "cnet_db4_batch",
|
||||||
|
2: "cnet_db4_height2",
|
||||||
|
3: "cnet_db4_width2",
|
||||||
|
},
|
||||||
|
"down_block_5": {
|
||||||
|
0: "cnet_db5_batch",
|
||||||
|
2: "cnet_db5_height2",
|
||||||
|
3: "cnet_db5_width2",
|
||||||
|
},
|
||||||
|
"down_block_6": {
|
||||||
|
0: "cnet_db6_batch",
|
||||||
|
2: "cnet_db6_height4",
|
||||||
|
3: "cnet_db6_width4",
|
||||||
|
},
|
||||||
|
"down_block_7": {
|
||||||
|
0: "cnet_db7_batch",
|
||||||
|
2: "cnet_db7_height4",
|
||||||
|
3: "cnet_db7_width4",
|
||||||
|
},
|
||||||
|
"down_block_8": {
|
||||||
|
0: "cnet_db8_batch",
|
||||||
|
2: "cnet_db8_height4",
|
||||||
|
3: "cnet_db8_width4",
|
||||||
|
},
|
||||||
|
"down_block_9": {
|
||||||
|
0: "cnet_db9_batch",
|
||||||
|
2: "cnet_db9_height8",
|
||||||
|
3: "cnet_db9_width8",
|
||||||
|
},
|
||||||
|
"down_block_10": {
|
||||||
|
0: "cnet_db10_batch",
|
||||||
|
2: "cnet_db10_height8",
|
||||||
|
3: "cnet_db10_width8",
|
||||||
|
},
|
||||||
|
"down_block_11": {
|
||||||
|
0: "cnet_db11_batch",
|
||||||
|
2: "cnet_db11_height8",
|
||||||
|
3: "cnet_db11_width8",
|
||||||
|
},
|
||||||
|
"mid_block_additional_residual": {
|
||||||
|
0: "cnet_mbar_batch",
|
||||||
|
2: "cnet_mbar_height8",
|
||||||
|
3: "cnet_mbar_width8",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
output_path=cnet_path,
|
||||||
opset=conversion.opset,
|
opset=conversion.opset,
|
||||||
half=conversion.half,
|
half=conversion.half,
|
||||||
external_data=True, # UNet is > 2GB, so the weights need to be split
|
external_data=True, # UNet is > 2GB, so the weights need to be split
|
||||||
|
|
|
@ -1,21 +1,3 @@
|
||||||
# This file contains a mix of Apache and GPL code and should be treated as a GPL resource
|
|
||||||
#
|
|
||||||
# Original attribution:
|
|
||||||
#
|
|
||||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -27,7 +9,7 @@ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UNet2DConditionModel_Cnet(UNet2DConditionModel):
|
class UNet2DConditionModel_CNet(UNet2DConditionModel):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
sample: torch.FloatTensor,
|
sample: torch.FloatTensor,
|
||||||
|
@ -49,15 +31,24 @@ class UNet2DConditionModel_Cnet(UNet2DConditionModel):
|
||||||
return_dict: bool = False,
|
return_dict: bool = False,
|
||||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||||
down_block_add_res = (
|
down_block_add_res = (
|
||||||
down_block_add_res00, down_block_add_res01, down_block_add_res02,
|
down_block_add_res00,
|
||||||
down_block_add_res03, down_block_add_res04, down_block_add_res05,
|
down_block_add_res01,
|
||||||
down_block_add_res06, down_block_add_res07, down_block_add_res08,
|
down_block_add_res02,
|
||||||
down_block_add_res09, down_block_add_res10, down_block_add_res11)
|
down_block_add_res03,
|
||||||
|
down_block_add_res04,
|
||||||
|
down_block_add_res05,
|
||||||
|
down_block_add_res06,
|
||||||
|
down_block_add_res07,
|
||||||
|
down_block_add_res08,
|
||||||
|
down_block_add_res09,
|
||||||
|
down_block_add_res10,
|
||||||
|
down_block_add_res11,
|
||||||
|
)
|
||||||
return super().forward(
|
return super().forward(
|
||||||
sample=sample,
|
sample=sample,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
down_block_additional_residuals=down_block_add_res,
|
down_block_additional_residuals=down_block_add_res,
|
||||||
mid_block_additional_residual=mid_block_additional_residual,
|
mid_block_additional_residual=mid_block_additional_residual,
|
||||||
return_dict = return_dict
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue