feat(api): use wrapped model's input types in UNet patch
This commit is contained in:
parent
1d373faf5d
commit
80a255397e
|
@ -1,8 +1,9 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers import OnnxRuntimeModel
|
from diffusers import OnnxRuntimeModel
|
||||||
|
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
|
||||||
|
|
||||||
from ...server import ServerContext
|
from ...server import ServerContext
|
||||||
|
|
||||||
|
@ -10,6 +11,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UNetWrapper(object):
|
class UNetWrapper(object):
|
||||||
|
input_types: Optional[Dict[str, np.dtype]] = None
|
||||||
prompt_embeds: Optional[List[np.ndarray]] = None
|
prompt_embeds: Optional[List[np.ndarray]] = None
|
||||||
prompt_index: int = 0
|
prompt_index: int = 0
|
||||||
server: ServerContext
|
server: ServerContext
|
||||||
|
@ -26,6 +28,8 @@ class UNetWrapper(object):
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.xl = xl
|
self.xl = xl
|
||||||
|
|
||||||
|
self.cache_input_types()
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
sample: Optional[np.ndarray] = None,
|
sample: Optional[np.ndarray] = None,
|
||||||
|
@ -46,23 +50,22 @@ class UNetWrapper(object):
|
||||||
encoder_hidden_states = self.prompt_embeds[step_index]
|
encoder_hidden_states = self.prompt_embeds[step_index]
|
||||||
self.prompt_index += 1
|
self.prompt_index += 1
|
||||||
|
|
||||||
if self.xl:
|
if self.input_types is None:
|
||||||
# for XL, the sample and hidden states should match
|
self.cache_input_types()
|
||||||
if sample.dtype != encoder_hidden_states.dtype:
|
|
||||||
logger.trace(
|
|
||||||
"converting UNet sample to hidden state dtype for XL: %s",
|
|
||||||
encoder_hidden_states.dtype,
|
|
||||||
)
|
|
||||||
sample = sample.astype(encoder_hidden_states.dtype)
|
|
||||||
elif timestep.dtype != np.int64:
|
|
||||||
# the optimum converter uses an int timestep
|
|
||||||
if sample.dtype != timestep.dtype:
|
|
||||||
logger.trace("converting UNet sample to timestep dtype")
|
|
||||||
sample = sample.astype(timestep.dtype)
|
|
||||||
|
|
||||||
if encoder_hidden_states.dtype != timestep.dtype:
|
if encoder_hidden_states.dtype != self.input_types["encoder_hidden_states"]:
|
||||||
logger.trace("converting UNet hidden states to timestep dtype")
|
logger.trace("converting UNet hidden states to input dtype")
|
||||||
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
|
encoder_hidden_states = encoder_hidden_states.astype(
|
||||||
|
self.input_types["encoder_hidden_states"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if sample.dtype != self.input_types["sample"]:
|
||||||
|
logger.trace("converting UNet sample to input dtype")
|
||||||
|
sample = sample.astype(self.input_types["sample"])
|
||||||
|
|
||||||
|
if timestep.dtype != self.input_types["timestep"]:
|
||||||
|
logger.trace("converting UNet timestep to input dtype")
|
||||||
|
timestep = timestep.astype(self.input_types["timestep"])
|
||||||
|
|
||||||
return self.wrapped(
|
return self.wrapped(
|
||||||
sample=sample,
|
sample=sample,
|
||||||
|
@ -74,6 +77,25 @@ class UNetWrapper(object):
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
return getattr(self.wrapped, attr)
|
return getattr(self.wrapped, attr)
|
||||||
|
|
||||||
|
def cache_input_types(self):
|
||||||
|
# TODO: use server dtype as default
|
||||||
|
self.input_types = dict(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
input.name,
|
||||||
|
next(
|
||||||
|
[
|
||||||
|
TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
|
||||||
|
for field in input.type.ListFields()
|
||||||
|
],
|
||||||
|
np.float32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for input in self.wrapped.model.graph.input
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger.debug("cached UNet input types: %s", self.input_types)
|
||||||
|
|
||||||
def set_prompts(self, prompt_embeds: List[np.ndarray]):
|
def set_prompts(self, prompt_embeds: List[np.ndarray]):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
|
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
|
||||||
|
|
Loading…
Reference in New Issue