1
0
Fork 0

feat(api): use wrapped model's input types in UNet patch

This commit is contained in:
Sean Sube 2023-12-24 22:21:52 -06:00
parent 1d373faf5d
commit 80a255397e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 39 additions and 17 deletions

View File

@ -1,8 +1,9 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Dict, List, Optional
import numpy as np import numpy as np
from diffusers import OnnxRuntimeModel from diffusers import OnnxRuntimeModel
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from ...server import ServerContext from ...server import ServerContext
@ -10,6 +11,7 @@ logger = getLogger(__name__)
class UNetWrapper(object): class UNetWrapper(object):
input_types: Optional[Dict[str, np.dtype]] = None
prompt_embeds: Optional[List[np.ndarray]] = None prompt_embeds: Optional[List[np.ndarray]] = None
prompt_index: int = 0 prompt_index: int = 0
server: ServerContext server: ServerContext
@ -26,6 +28,8 @@ class UNetWrapper(object):
self.wrapped = wrapped self.wrapped = wrapped
self.xl = xl self.xl = xl
self.cache_input_types()
def __call__( def __call__(
self, self,
sample: Optional[np.ndarray] = None, sample: Optional[np.ndarray] = None,
@ -46,23 +50,22 @@ class UNetWrapper(object):
encoder_hidden_states = self.prompt_embeds[step_index] encoder_hidden_states = self.prompt_embeds[step_index]
self.prompt_index += 1 self.prompt_index += 1
if self.xl: if self.input_types is None:
# for XL, the sample and hidden states should match self.cache_input_types()
if sample.dtype != encoder_hidden_states.dtype:
logger.trace(
"converting UNet sample to hidden state dtype for XL: %s",
encoder_hidden_states.dtype,
)
sample = sample.astype(encoder_hidden_states.dtype)
elif timestep.dtype != np.int64:
# the optimum converter uses an int timestep
if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype)
if encoder_hidden_states.dtype != timestep.dtype: if encoder_hidden_states.dtype != self.input_types["encoder_hidden_states"]:
logger.trace("converting UNet hidden states to timestep dtype") logger.trace("converting UNet hidden states to input dtype")
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype) encoder_hidden_states = encoder_hidden_states.astype(
self.input_types["encoder_hidden_states"]
)
if sample.dtype != self.input_types["sample"]:
logger.trace("converting UNet sample to input dtype")
sample = sample.astype(self.input_types["sample"])
if timestep.dtype != self.input_types["timestep"]:
logger.trace("converting UNet timestep to input dtype")
timestep = timestep.astype(self.input_types["timestep"])
return self.wrapped( return self.wrapped(
sample=sample, sample=sample,
@ -74,6 +77,25 @@ class UNetWrapper(object):
def __getattr__(self, attr): def __getattr__(self, attr):
return getattr(self.wrapped, attr) return getattr(self.wrapped, attr)
def cache_input_types(self):
    """Inspect the wrapped ONNX model's graph inputs and cache a mapping
    of input name -> numpy dtype in ``self.input_types``.

    Inputs whose type proto yields no element type fall back to
    ``np.float32``.
    """
    # TODO: use server dtype as the fallback default instead of float32
    self.input_types = {
        # NOTE(review): graph_input.type is a protobuf TypeProto;
        # ListFields() yields (descriptor, value) pairs — presumably the
        # tensor_type field whose elem_type indexes the dtype table.
        # TODO confirm behavior for non-tensor (sequence/map) inputs.
        graph_input.name: next(
            # BUG FIX: the original passed a *list* as the first argument
            # to next(), which raises TypeError ('list' object is not an
            # iterator); a generator expression makes the fallback work.
            (
                TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
                for field in graph_input.type.ListFields()
            ),
            np.float32,
        )
        for graph_input in self.wrapped.model.graph.input
    }
    logger.debug("cached UNet input types: %s", self.input_types)
def set_prompts(self, prompt_embeds: List[np.ndarray]): def set_prompts(self, prompt_embeds: List[np.ndarray]):
logger.debug( logger.debug(
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds] "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]