1
0
Fork 0

update forwarding shapes, apply lint

This commit is contained in:
Sean Sube 2023-12-19 22:10:37 -06:00
parent b4262eb777
commit 4a87fb2a31
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 95 additions and 47 deletions

View File

@ -190,7 +190,6 @@ def convert_diffusion_diffusers_cnet(
),
False,
),
output_path=cnet_path,
ordered_input_names=[
"sample",
"timestep",
@ -214,23 +213,81 @@ def convert_diffusion_diffusers_cnet(
"out_sample"
], # has to be different from "sample" for correct tracing
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
"down_block_0": {0: "batch", 2: "height", 3: "width"},
"down_block_1": {0: "batch", 2: "height", 3: "width"},
"down_block_2": {0: "batch", 2: "height", 3: "width"},
"down_block_3": {0: "batch", 2: "height2", 3: "width2"},
"down_block_4": {0: "batch", 2: "height2", 3: "width2"},
"down_block_5": {0: "batch", 2: "height2", 3: "width2"},
"down_block_6": {0: "batch", 2: "height4", 3: "width4"},
"down_block_7": {0: "batch", 2: "height4", 3: "width4"},
"down_block_8": {0: "batch", 2: "height4", 3: "width4"},
"down_block_9": {0: "batch", 2: "height8", 3: "width8"},
"down_block_10": {0: "batch", 2: "height8", 3: "width8"},
"down_block_11": {0: "batch", 2: "height8", 3: "width8"},
"mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
"sample": {
0: "cnet_sample_batch",
1: "cnet_sample_channels",
2: "cnet_sample_height",
3: "cnet_sample_width",
},
"timestep": {0: "cnet_timestep_batch"},
"encoder_hidden_states": {0: "cnet_ehs_batch", 1: "cnet_ehs_sequence"},
"down_block_0": {
0: "cnet_db0_batch",
2: "cnet_db0_height",
3: "cnet_db0_width",
},
"down_block_1": {
0: "cnet_db1_batch",
2: "cnet_db1_height",
3: "cnet_db1_width",
},
"down_block_2": {
0: "cnet_db2_batch",
2: "cnet_db2_height",
3: "cnet_db2_width",
},
"down_block_3": {
0: "cnet_db3_batch",
2: "cnet_db3_height2",
3: "cnet_db3_width2",
},
"down_block_4": {
0: "cnet_db4_batch",
2: "cnet_db4_height2",
3: "cnet_db4_width2",
},
"down_block_5": {
0: "cnet_db5_batch",
2: "cnet_db5_height2",
3: "cnet_db5_width2",
},
"down_block_6": {
0: "cnet_db6_batch",
2: "cnet_db6_height4",
3: "cnet_db6_width4",
},
"down_block_7": {
0: "cnet_db7_batch",
2: "cnet_db7_height4",
3: "cnet_db7_width4",
},
"down_block_8": {
0: "cnet_db8_batch",
2: "cnet_db8_height4",
3: "cnet_db8_width4",
},
"down_block_9": {
0: "cnet_db9_batch",
2: "cnet_db9_height8",
3: "cnet_db9_width8",
},
"down_block_10": {
0: "cnet_db10_batch",
2: "cnet_db10_height8",
3: "cnet_db10_width8",
},
"down_block_11": {
0: "cnet_db11_batch",
2: "cnet_db11_height8",
3: "cnet_db11_width8",
},
"mid_block_additional_residual": {
0: "cnet_mbar_batch",
2: "cnet_mbar_height8",
3: "cnet_mbar_width8",
},
},
output_path=cnet_path,
opset=conversion.opset,
half=conversion.half,
external_data=True, # UNet is > 2GB, so the weights need to be split

View File

@ -1,21 +1,3 @@
# This file contains a mix of Apache and GPL code and should be treated as a GPL resource
#
# Original attribution:
#
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from typing import Optional, Tuple, Union
@ -27,7 +9,7 @@ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
logger = getLogger(__name__)
class UNet2DConditionModel_Cnet(UNet2DConditionModel):
class UNet2DConditionModel_CNet(UNet2DConditionModel):
def forward(
self,
sample: torch.FloatTensor,
@ -49,15 +31,24 @@ class UNet2DConditionModel_Cnet(UNet2DConditionModel):
return_dict: bool = False,
) -> Union[UNet2DConditionOutput, Tuple]:
down_block_add_res = (
down_block_add_res00, down_block_add_res01, down_block_add_res02,
down_block_add_res03, down_block_add_res04, down_block_add_res05,
down_block_add_res06, down_block_add_res07, down_block_add_res08,
down_block_add_res09, down_block_add_res10, down_block_add_res11)
down_block_add_res00,
down_block_add_res01,
down_block_add_res02,
down_block_add_res03,
down_block_add_res04,
down_block_add_res05,
down_block_add_res06,
down_block_add_res07,
down_block_add_res08,
down_block_add_res09,
down_block_add_res10,
down_block_add_res11,
)
return super().forward(
sample=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
down_block_additional_residuals=down_block_add_res,
mid_block_additional_residual=mid_block_additional_residual,
return_dict = return_dict
return_dict=return_dict,
)