
feat(api): use wrapped model's input types in UNet patch

Sean Sube 2023-12-24 22:21:52 -06:00
parent 1d373faf5d
commit 80a255397e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 39 additions and 17 deletions


@@ -1,8 +1,9 @@
 from logging import getLogger
-from typing import List, Optional
+from typing import Dict, List, Optional
 import numpy as np
 from diffusers import OnnxRuntimeModel
+from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 from ...server import ServerContext
@@ -10,6 +11,7 @@ logger = getLogger(__name__)
 class UNetWrapper(object):
+    input_types: Optional[Dict[str, np.dtype]] = None
     prompt_embeds: Optional[List[np.ndarray]] = None
     prompt_index: int = 0
     server: ServerContext
@@ -26,6 +28,8 @@ class UNetWrapper(object):
         self.wrapped = wrapped
         self.xl = xl
+        self.cache_input_types()

     def __call__(
         self,
         sample: Optional[np.ndarray] = None,
@@ -46,23 +50,22 @@
             encoder_hidden_states = self.prompt_embeds[step_index]
             self.prompt_index += 1

-        if self.xl:
-            # for XL, the sample and hidden states should match
-            if sample.dtype != encoder_hidden_states.dtype:
-                logger.trace(
-                    "converting UNet sample to hidden state dtype for XL: %s",
-                    encoder_hidden_states.dtype,
-                )
-                sample = sample.astype(encoder_hidden_states.dtype)
-        elif timestep.dtype != np.int64:
-            # the optimum converter uses an int timestep
-            if sample.dtype != timestep.dtype:
-                logger.trace("converting UNet sample to timestep dtype")
-                sample = sample.astype(timestep.dtype)
+        if self.input_types is None:
+            self.cache_input_types()

-            if encoder_hidden_states.dtype != timestep.dtype:
-                logger.trace("converting UNet hidden states to timestep dtype")
-                encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
+        if encoder_hidden_states.dtype != self.input_types["encoder_hidden_states"]:
+            logger.trace("converting UNet hidden states to input dtype")
+            encoder_hidden_states = encoder_hidden_states.astype(
+                self.input_types["encoder_hidden_states"]
+            )
+        if sample.dtype != self.input_types["sample"]:
+            logger.trace("converting UNet sample to input dtype")
+            sample = sample.astype(self.input_types["sample"])
+        if timestep.dtype != self.input_types["timestep"]:
+            logger.trace("converting UNet timestep to input dtype")
+            timestep = timestep.astype(self.input_types["timestep"])

         return self.wrapped(
             sample=sample,
@@ -74,6 +77,25 @@ class UNetWrapper(object):
     def __getattr__(self, attr):
         return getattr(self.wrapped, attr)

+    def cache_input_types(self):
+        # TODO: use server dtype as default
+        self.input_types = dict(
+            [
+                (
+                    input.name,
+                    next(
+                        [
+                            TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
+                            for field in input.type.ListFields()
+                        ],
+                        np.float32,
+                    ),
+                )
+                for input in self.wrapped.model.graph.input
+            ]
+        )
+        logger.debug("cached UNet input types: %s", self.input_types)

     def set_prompts(self, prompt_embeds: List[np.ndarray]):
         logger.debug(
             "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]