55 lines
2.1 KiB
Python
55 lines
2.1 KiB
Python
from logging import getLogger
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from diffusers import UNet2DConditionModel
|
|
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
|
|
|
# Per-module logger, looked up by this module's dotted import name.
logger = getLogger(name=__name__)
|
|
|
|
|
|
class UNet2DConditionModel_CNet(UNet2DConditionModel):
    """``UNet2DConditionModel`` taking down-block residuals as separate arguments.

    The parent ``forward`` accepts the down-block additional residuals as one
    tuple (``down_block_additional_residuals``). This subclass instead exposes
    twelve individually named optional tensor parameters
    (``down_block_add_res00`` .. ``down_block_add_res11``). It re-packs them
    into that tuple before delegating to the parent.

    NOTE(review): presumably the flat signature exists so that tracing/export
    tooling, which needs named tensor inputs rather than a tuple, can drive the
    model. Confirm with the callers.
    """

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        down_block_add_res00: Optional[torch.Tensor] = None,
        down_block_add_res01: Optional[torch.Tensor] = None,
        down_block_add_res02: Optional[torch.Tensor] = None,
        down_block_add_res03: Optional[torch.Tensor] = None,
        down_block_add_res04: Optional[torch.Tensor] = None,
        down_block_add_res05: Optional[torch.Tensor] = None,
        down_block_add_res06: Optional[torch.Tensor] = None,
        down_block_add_res07: Optional[torch.Tensor] = None,
        down_block_add_res08: Optional[torch.Tensor] = None,
        down_block_add_res09: Optional[torch.Tensor] = None,
        down_block_add_res10: Optional[torch.Tensor] = None,
        down_block_add_res11: Optional[torch.Tensor] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        """Run the parent UNet with the flattened residuals re-packed.

        Args:
            sample: Passed unchanged to the parent as ``sample``.
            timestep: Passed unchanged to the parent as ``timestep``.
            encoder_hidden_states: Passed unchanged to the parent as
                ``encoder_hidden_states``.
            down_block_add_res00 .. down_block_add_res11: Down-block
                residuals, in order. They are collected into one tuple for the
                parent's ``down_block_additional_residuals``. If every one of
                them is ``None``, the parent receives ``None`` rather than a
                tuple of ``None`` values.
            mid_block_additional_residual: Passed unchanged to the parent.
            return_dict: Forwarded to the parent. The default here is
                ``False``, so a tuple is returned unless the caller asks
                otherwise.

        Returns:
            Whatever the parent ``forward`` returns for the given
            ``return_dict``.
        """
        down_block_add_res: Optional[Tuple[Optional[torch.Tensor], ...]] = (
            down_block_add_res00,
            down_block_add_res01,
            down_block_add_res02,
            down_block_add_res03,
            down_block_add_res04,
            down_block_add_res05,
            down_block_add_res06,
            down_block_add_res07,
            down_block_add_res08,
            down_block_add_res09,
            down_block_add_res10,
            down_block_add_res11,
        )

        # An all-None tuple is not the same as "no residuals". The parent
        # decides whether to apply residuals by testing the argument itself
        # against None. A tuple of Nones would therefore send it down the
        # residual-addition path and fail when adding None to a tensor.
        # NOTE(review): based on diffusers' forward checking
        # `down_block_additional_residuals is not None`; re-verify against the
        # pinned diffusers version.
        if all(res is None for res in down_block_add_res):
            down_block_add_res = None

        return super().forward(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_add_res,
            mid_block_additional_residual=mid_block_additional_residual,
            return_dict=return_dict,
        )
|