# onnx-web/api/onnx_web/models/cnet.py

from logging import getLogger
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
# Module-level logger, named after this module.
logger = getLogger(name=__name__)
class UNet2DConditionModel_CNet(UNet2DConditionModel):
    """``UNet2DConditionModel`` taking ControlNet down-block residuals as flat args.

    ``forward`` exposes the down-block residuals as twelve separate optional
    tensor parameters (``down_block_add_res00`` .. ``down_block_add_res11``).
    It repacks them into the single ``down_block_additional_residuals``
    argument of the parent ``forward``.

    NOTE(review): the flat signature presumably exists so the model can be
    traced/exported (e.g. to ONNX) with one named input per residual; confirm
    against the export code.
    """

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        down_block_add_res00: Optional[torch.Tensor] = None,
        down_block_add_res01: Optional[torch.Tensor] = None,
        down_block_add_res02: Optional[torch.Tensor] = None,
        down_block_add_res03: Optional[torch.Tensor] = None,
        down_block_add_res04: Optional[torch.Tensor] = None,
        down_block_add_res05: Optional[torch.Tensor] = None,
        down_block_add_res06: Optional[torch.Tensor] = None,
        down_block_add_res07: Optional[torch.Tensor] = None,
        down_block_add_res08: Optional[torch.Tensor] = None,
        down_block_add_res09: Optional[torch.Tensor] = None,
        down_block_add_res10: Optional[torch.Tensor] = None,
        down_block_add_res11: Optional[torch.Tensor] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        """Run the parent UNet, optionally adding ControlNet residuals.

        Args:
            sample: Forwarded unchanged to the parent ``forward``.
            timestep: Forwarded unchanged to the parent ``forward``.
            encoder_hidden_states: Forwarded unchanged to the parent ``forward``.
            down_block_add_res00 .. down_block_add_res11: Down-block residuals,
                in order. They are packed into a tuple and passed as
                ``down_block_additional_residuals``. If none of them is given,
                ``None`` is passed instead of a tuple of ``None`` values.
            mid_block_additional_residual: Forwarded unchanged to the parent.
            return_dict: Forwarded to the parent ``forward``; defaults to
                ``False`` in this override.

        Returns:
            Whatever the parent ``forward`` returns for these arguments.
        """
        down_block_add_res = (
            down_block_add_res00,
            down_block_add_res01,
            down_block_add_res02,
            down_block_add_res03,
            down_block_add_res04,
            down_block_add_res05,
            down_block_add_res06,
            down_block_add_res07,
            down_block_add_res08,
            down_block_add_res09,
            down_block_add_res10,
            down_block_add_res11,
        )

        # A tuple is never ``None``, so always forwarding the tuple would make
        # the parent treat a call without residuals as if residuals were given.
        # Collapse the "nothing supplied" case to ``None`` so the optional
        # defaults actually work. Partial sets are passed through untouched.
        if all(res is None for res in down_block_add_res):
            down_block_additional_residuals = None
        else:
            down_block_additional_residuals = down_block_add_res

        return super().forward(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
            return_dict=return_dict,
        )