feat(api): use wrapped model's input types in UNet patch
This commit is contained in:
parent
1d373faf5d
commit
80a255397e
|
@ -1,8 +1,9 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from diffusers import OnnxRuntimeModel
|
||||
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
|
||||
|
||||
from ...server import ServerContext
|
||||
|
||||
|
@ -10,6 +11,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class UNetWrapper(object):
|
||||
input_types: Optional[Dict[str, np.dtype]] = None
|
||||
prompt_embeds: Optional[List[np.ndarray]] = None
|
||||
prompt_index: int = 0
|
||||
server: ServerContext
|
||||
|
@ -26,6 +28,8 @@ class UNetWrapper(object):
|
|||
self.wrapped = wrapped
|
||||
self.xl = xl
|
||||
|
||||
self.cache_input_types()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample: Optional[np.ndarray] = None,
|
||||
|
@ -46,23 +50,22 @@ class UNetWrapper(object):
|
|||
encoder_hidden_states = self.prompt_embeds[step_index]
|
||||
self.prompt_index += 1
|
||||
|
||||
if self.xl:
|
||||
# for XL, the sample and hidden states should match
|
||||
if sample.dtype != encoder_hidden_states.dtype:
|
||||
logger.trace(
|
||||
"converting UNet sample to hidden state dtype for XL: %s",
|
||||
encoder_hidden_states.dtype,
|
||||
)
|
||||
sample = sample.astype(encoder_hidden_states.dtype)
|
||||
elif timestep.dtype != np.int64:
|
||||
# the optimum converter uses an int timestep
|
||||
if sample.dtype != timestep.dtype:
|
||||
logger.trace("converting UNet sample to timestep dtype")
|
||||
sample = sample.astype(timestep.dtype)
|
||||
if self.input_types is None:
|
||||
self.cache_input_types()
|
||||
|
||||
if encoder_hidden_states.dtype != timestep.dtype:
|
||||
logger.trace("converting UNet hidden states to timestep dtype")
|
||||
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
|
||||
if encoder_hidden_states.dtype != self.input_types["encoder_hidden_states"]:
|
||||
logger.trace("converting UNet hidden states to input dtype")
|
||||
encoder_hidden_states = encoder_hidden_states.astype(
|
||||
self.input_types["encoder_hidden_states"]
|
||||
)
|
||||
|
||||
if sample.dtype != self.input_types["sample"]:
|
||||
logger.trace("converting UNet sample to input dtype")
|
||||
sample = sample.astype(self.input_types["sample"])
|
||||
|
||||
if timestep.dtype != self.input_types["timestep"]:
|
||||
logger.trace("converting UNet timestep to input dtype")
|
||||
timestep = timestep.astype(self.input_types["timestep"])
|
||||
|
||||
return self.wrapped(
|
||||
sample=sample,
|
||||
|
@ -74,6 +77,25 @@ class UNetWrapper(object):
|
|||
def __getattr__(self, attr):
    """Delegate lookups for attributes missing on the wrapper to the wrapped model.

    Python only invokes this hook after normal attribute lookup has failed,
    so attributes set directly on the wrapper are never routed through here.

    Raises:
        AttributeError: if ``wrapped`` itself has not been set yet, or if the
            wrapped model does not provide the attribute either.
    """
    # If `wrapped` is missing from the instance, reading self.wrapped below
    # would re-enter this hook and recurse until RecursionError; fail with a
    # regular AttributeError instead so hasattr()/copy/pickle behave sanely.
    if attr == "wrapped":
        raise AttributeError(attr)

    return getattr(self.wrapped, attr)
|
||||
|
||||
def cache_input_types(self):
    """Cache the numpy dtype expected by each input of the wrapped model.

    Iterates ``self.wrapped.model.graph.input`` and stores a mapping of
    input name to numpy dtype in ``self.input_types``. For each input, the
    dtype is looked up in ``TENSOR_TYPE_TO_NP_TYPE`` using the ``elem_type``
    of the first entry returned by ``input.type.ListFields()``; inputs with
    no such entry fall back to ``np.float32``.
    """
    # TODO: use server dtype as default
    self.input_types = {
        graph_input.name: next(
            # next() needs an iterator: handing it a list raises TypeError,
            # so feed it a generator and let the default cover the empty case.
            (
                TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
                for field in graph_input.type.ListFields()
            ),
            np.float32,
        )
        for graph_input in self.wrapped.model.graph.input
    }
    logger.debug("cached UNet input types: %s", self.input_types)
|
||||
|
||||
def set_prompts(self, prompt_embeds: List[np.ndarray]):
|
||||
logger.debug(
|
||||
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
|
||||
|
|
Loading…
Reference in New Issue