from logging import getLogger
from typing import Dict, List, Optional, Union

import numpy as np
from diffusers import OnnxRuntimeModel
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
from optimum.onnxruntime.modeling_diffusion import ORTModelUnet

from ...server import ServerContext

# Module-level logger named after this module (standard logging convention).
logger = getLogger(__name__)
class UNetWrapper(object):
    """Transparent wrapper around an ONNX UNet model.

    Responsibilities:
      * cast the ``sample``, ``timestep`` and ``encoder_hidden_states``
        arrays to the dtypes the underlying ONNX session declares for its
        inputs (declared types are cached in ``input_types``);
      * optionally substitute a different prompt embedding on each call,
        cycling through the list installed by ``set_prompts``.

    Any attribute not defined here is forwarded to the wrapped model via
    ``__getattr__``, so instances can stand in for the original UNet
    inside a diffusers pipeline.
    """

    # ONNX input name -> numpy dtype, read from the session.
    # None until cache_input_types() has run.
    input_types: Optional[Dict[str, np.dtype]] = None

    # When set, one embedding per step; __call__ ignores the hidden states
    # it was given and cycles through these instead.
    prompt_embeds: Optional[List[np.ndarray]] = None

    # Index of the next entry in prompt_embeds to use.
    prompt_index: int = 0

    # Fallback dtype for sample/hidden-state inputs when the session does
    # not declare one.
    sample_dtype: np.dtype

    server: ServerContext

    # Fallback dtype for the timestep input when the session does not
    # declare one.
    timestep_dtype: np.dtype

    wrapped: Union[OnnxRuntimeModel, ORTModelUnet]

    # True when wrapping an SDXL UNet.
    xl: bool

    def __init__(
        self,
        server: ServerContext,
        wrapped: Union[OnnxRuntimeModel, ORTModelUnet],
        xl: bool,
        sample_dtype: Optional[np.dtype] = None,
        timestep_dtype: np.dtype = np.int64,
    ):
        """
        Args:
            server: server context; ``server.torch_dtype`` is used as the
                default sample dtype when ``sample_dtype`` is not given.
            wrapped: the underlying UNet model.
            xl: whether the wrapped model is an SDXL UNet.
            sample_dtype: fallback dtype for sample/hidden-state inputs.
            timestep_dtype: fallback dtype for the timestep input.
        """
        self.server = server
        self.wrapped = wrapped
        self.xl = xl
        self.sample_dtype = sample_dtype or server.torch_dtype
        self.timestep_dtype = timestep_dtype

        # Eagerly cache the session's declared input dtypes; __call__ will
        # retry lazily if this is cleared later.
        self.cache_input_types()

    def __call__(
        self,
        sample: Optional[np.ndarray] = None,
        timestep: Optional[np.ndarray] = None,
        encoder_hidden_states: Optional[np.ndarray] = None,
        **kwargs,
    ):
        """Run the wrapped UNet, casting each input to the dtype the ONNX
        session declares for it.

        NOTE(review): the three parameters default to None but their
        ``.dtype`` is read unconditionally below, so in practice all
        three must be provided — confirm against callers before
        tightening the signature.
        """
        logger.trace(
            "UNet parameter types: %s, %s, %s",
            sample.dtype,
            timestep.dtype,
            encoder_hidden_states.dtype,
        )

        # Truthiness check (not `is not None`) so an empty list from
        # set_prompts cannot cause a ZeroDivisionError in the modulo below.
        if self.prompt_embeds:
            # Wrap around so more steps than embeds reuses them in order.
            step_index = self.prompt_index % len(self.prompt_embeds)
            logger.trace("multiple prompt embeds found, using step: %s", step_index)
            encoder_hidden_states = self.prompt_embeds[step_index]
            self.prompt_index += 1

        if self.input_types is None:
            self.cache_input_types()

        # Cast hidden states to the session's declared dtype, falling back
        # to the configured sample dtype.
        encoder_hidden_states_input_dtype = self.input_types.get(
            "encoder_hidden_states", self.sample_dtype
        )
        if encoder_hidden_states.dtype != encoder_hidden_states_input_dtype:
            logger.debug(
                "converting UNet hidden states to input dtype from %s to %s",
                encoder_hidden_states.dtype,
                encoder_hidden_states_input_dtype,
            )
            encoder_hidden_states = encoder_hidden_states.astype(
                encoder_hidden_states_input_dtype
            )

        # Cast the latent sample the same way.
        sample_input_dtype = self.input_types.get("sample", self.sample_dtype)
        if sample.dtype != sample_input_dtype:
            logger.debug(
                "converting UNet sample to input dtype from %s to %s",
                sample.dtype,
                sample_input_dtype,
            )
            sample = sample.astype(sample_input_dtype)

        # Timesteps fall back to the configured timestep dtype (int64 by
        # default) rather than the sample dtype.
        timestep_input_dtype = self.input_types.get("timestep", self.timestep_dtype)
        if timestep.dtype != timestep_input_dtype:
            logger.debug(
                "converting UNet timestep to input dtype from %s to %s",
                timestep.dtype,
                timestep_input_dtype,
            )
            timestep = timestep.astype(timestep_input_dtype)

        return self.wrapped(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            **kwargs,
        )

    def __getattr__(self, attr):
        """Forward unknown attributes to the wrapped model."""
        return getattr(self.wrapped, attr)

    def cache_input_types(self):
        """Read the ONNX session's declared inputs and cache a mapping of
        input name to numpy dtype in ``input_types``.

        Raises:
            ValueError: if the wrapped model is not one of the two
                supported UNet classes.
        """
        if isinstance(self.wrapped, ORTModelUnet):
            session = self.wrapped.session
        elif isinstance(self.wrapped, OnnxRuntimeModel):
            session = self.wrapped.model
        else:
            raise ValueError("unknown UNet class")

        inputs = session.get_inputs()
        self.input_types = dict(
            [(input.name, ORT_TO_NP_TYPE[input.type]) for input in inputs]
        )
        logger.debug("cached UNet input types: %s", self.input_types)

    def set_prompts(self, prompt_embeds: List[np.ndarray]):
        """Install per-step prompt embeddings and reset the step counter.

        Subsequent calls will cycle through these embeddings in order,
        overriding the ``encoder_hidden_states`` argument.
        """
        logger.debug(
            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
        )
        self.prompt_embeds = prompt_embeds
        self.prompt_index = 0