diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index 2f1899a5..07b18d10 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -190,7 +190,6 @@ def convert_diffusion_diffusers_cnet( ), False, ), - output_path=cnet_path, ordered_input_names=[ "sample", "timestep", @@ -214,23 +213,81 @@ def convert_diffusion_diffusers_cnet( "out_sample" ], # has to be different from "sample" for correct tracing dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - "timestep": {0: "batch"}, - "encoder_hidden_states": {0: "batch", 1: "sequence"}, - "down_block_0": {0: "batch", 2: "height", 3: "width"}, - "down_block_1": {0: "batch", 2: "height", 3: "width"}, - "down_block_2": {0: "batch", 2: "height", 3: "width"}, - "down_block_3": {0: "batch", 2: "height2", 3: "width2"}, - "down_block_4": {0: "batch", 2: "height2", 3: "width2"}, - "down_block_5": {0: "batch", 2: "height2", 3: "width2"}, - "down_block_6": {0: "batch", 2: "height4", 3: "width4"}, - "down_block_7": {0: "batch", 2: "height4", 3: "width4"}, - "down_block_8": {0: "batch", 2: "height4", 3: "width4"}, - "down_block_9": {0: "batch", 2: "height8", 3: "width8"}, - "down_block_10": {0: "batch", 2: "height8", 3: "width8"}, - "down_block_11": {0: "batch", 2: "height8", 3: "width8"}, - "mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"}, + "sample": { + 0: "cnet_sample_batch", + 1: "cnet_sample_channels", + 2: "cnet_sample_height", + 3: "cnet_sample_width", + }, + "timestep": {0: "cnet_timestep_batch"}, + "encoder_hidden_states": {0: "cnet_ehs_batch", 1: "cnet_ehs_sequence"}, + "down_block_0": { + 0: "cnet_db0_batch", + 2: "cnet_db0_height", + 3: "cnet_db0_width", + }, + "down_block_1": { + 0: "cnet_db1_batch", + 2: "cnet_db1_height", + 3: "cnet_db1_width", + }, + "down_block_2": { + 0: "cnet_db2_batch", + 2: "cnet_db2_height", + 3: "cnet_db2_width", + }, + "down_block_3": { + 0: "cnet_db3_batch", + 2: "cnet_db3_height2", + 3: "cnet_db3_width2", + }, + "down_block_4": { + 0: "cnet_db4_batch", + 2: "cnet_db4_height2", + 3: "cnet_db4_width2", + }, + "down_block_5": { + 0: "cnet_db5_batch", + 2: "cnet_db5_height2", + 3: "cnet_db5_width2", + }, + "down_block_6": { + 0: "cnet_db6_batch", + 2: "cnet_db6_height4", + 3: "cnet_db6_width4", + }, + "down_block_7": { + 0: "cnet_db7_batch", + 2: "cnet_db7_height4", + 3: "cnet_db7_width4", + }, + "down_block_8": { + 0: "cnet_db8_batch", + 2: "cnet_db8_height4", + 3: "cnet_db8_width4", + }, + "down_block_9": { + 0: "cnet_db9_batch", + 2: "cnet_db9_height8", + 3: "cnet_db9_width8", + }, + "down_block_10": { + 0: "cnet_db10_batch", + 2: "cnet_db10_height8", + 3: "cnet_db10_width8", + }, + "down_block_11": { + 0: "cnet_db11_batch", + 2: "cnet_db11_height8", + 3: "cnet_db11_width8", + }, + "mid_block_additional_residual": { + 0: "cnet_mbar_batch", + 2: "cnet_mbar_height8", + 3: "cnet_mbar_width8", + }, }, + output_path=cnet_path, opset=conversion.opset, half=conversion.half, external_data=True, # UNet is > 2GB, so the weights need to be split diff --git a/api/onnx_web/models/cnet.py b/api/onnx_web/models/cnet.py index 6243587d..15a930f0 100644 --- a/api/onnx_web/models/cnet.py +++ b/api/onnx_web/models/cnet.py @@ -1,21 +1,3 @@ -# This file contains a mix of Apache and GPL code and should be treated as a GPL resource -# -# Original attribution: -# -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from logging import getLogger from typing import Optional, Tuple, Union @@ -27,7 +9,7 @@ from diffusers.models.unet_2d_condition import UNet2DConditionOutput logger = getLogger(__name__) -class UNet2DConditionModel_Cnet(UNet2DConditionModel): +class UNet2DConditionModel_CNet(UNet2DConditionModel): def forward( self, sample: torch.FloatTensor, @@ -49,15 +31,24 @@ class UNet2DConditionModel_Cnet(UNet2DConditionModel): return_dict: bool = False, ) -> Union[UNet2DConditionOutput, Tuple]: down_block_add_res = ( - down_block_add_res00, down_block_add_res01, down_block_add_res02, - down_block_add_res03, down_block_add_res04, down_block_add_res05, - down_block_add_res06, down_block_add_res07, down_block_add_res08, - down_block_add_res09, down_block_add_res10, down_block_add_res11) - return super().forward( - sample = sample, - timestep = timestep, - encoder_hidden_states = encoder_hidden_states, - down_block_additional_residuals = down_block_add_res, - mid_block_additional_residual = mid_block_additional_residual, - return_dict = return_dict + down_block_add_res00, + down_block_add_res01, + down_block_add_res02, + down_block_add_res03, + down_block_add_res04, + down_block_add_res05, + down_block_add_res06, + down_block_add_res07, + down_block_add_res08, + down_block_add_res09, + down_block_add_res10, + down_block_add_res11, + ) + return super().forward( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=down_block_add_res, + mid_block_additional_residual=mid_block_additional_residual, + return_dict=return_dict, )