From 80a255397eb76a48b325706212472c1ed7e807c2 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 24 Dec 2023 22:21:52 -0600
Subject: [PATCH] feat(api): use wrapped model's input types in UNet patch

---
 api/onnx_web/diffusers/patches/unet.py | 56 ++++++++++++++++++--------
 1 file changed, 39 insertions(+), 17 deletions(-)

diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py
index b7c020de..dc7c1c94 100644
--- a/api/onnx_web/diffusers/patches/unet.py
+++ b/api/onnx_web/diffusers/patches/unet.py
@@ -1,8 +1,9 @@
 from logging import getLogger
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import numpy as np
 from diffusers import OnnxRuntimeModel
+from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 
 from ...server import ServerContext
 
@@ -10,6 +11,7 @@ logger = getLogger(__name__)
 
 
 class UNetWrapper(object):
+    input_types: Optional[Dict[str, np.dtype]] = None
     prompt_embeds: Optional[List[np.ndarray]] = None
     prompt_index: int = 0
     server: ServerContext
@@ -26,6 +28,8 @@ class UNetWrapper(object):
         self.wrapped = wrapped
         self.xl = xl
 
+        self.cache_input_types()
+
     def __call__(
         self,
         sample: Optional[np.ndarray] = None,
@@ -46,23 +50,22 @@
             encoder_hidden_states = self.prompt_embeds[step_index]
             self.prompt_index += 1
 
-        if self.xl:
-            # for XL, the sample and hidden states should match
-            if sample.dtype != encoder_hidden_states.dtype:
-                logger.trace(
-                    "converting UNet sample to hidden state dtype for XL: %s",
-                    encoder_hidden_states.dtype,
-                )
-                sample = sample.astype(encoder_hidden_states.dtype)
-        elif timestep.dtype != np.int64:
-            # the optimum converter uses an int timestep
-            if sample.dtype != timestep.dtype:
-                logger.trace("converting UNet sample to timestep dtype")
-                sample = sample.astype(timestep.dtype)
+        if self.input_types is None:
+            self.cache_input_types()
 
-            if encoder_hidden_states.dtype != timestep.dtype:
-                logger.trace("converting UNet hidden states to timestep dtype")
-                encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
+        if encoder_hidden_states.dtype != self.input_types["encoder_hidden_states"]:
+            logger.trace("converting UNet hidden states to input dtype")
+            encoder_hidden_states = encoder_hidden_states.astype(
+                self.input_types["encoder_hidden_states"]
+            )
+
+        if sample.dtype != self.input_types["sample"]:
+            logger.trace("converting UNet sample to input dtype")
+            sample = sample.astype(self.input_types["sample"])
+
+        if timestep.dtype != self.input_types["timestep"]:
+            logger.trace("converting UNet timestep to input dtype")
+            timestep = timestep.astype(self.input_types["timestep"])
 
         return self.wrapped(
             sample=sample,
@@ -74,6 +77,25 @@
     def __getattr__(self, attr):
         return getattr(self.wrapped, attr)
 
+    def cache_input_types(self):
+        # TODO: use server dtype as default
+        self.input_types = dict(
+            [
+                (
+                    input.name,
+                    next(
+                        (
+                            TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
+                            for field in input.type.ListFields()
+                        ),
+                        np.float32,
+                    ),
+                )
+                for input in self.wrapped.model.graph.input
+            ]
+        )
+        logger.debug("cached UNet input types: %s", self.input_types)
+
     def set_prompts(self, prompt_embeds: List[np.ndarray]):
         logger.debug(
             "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
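
The dtype lookup this patch adds can be reproduced standalone. The sketch
below is a minimal example, not part of the patch: it assumes the UNet is
available on disk as an exported ONNX ModelProto, and the model path is a
placeholder. It uses the same onnx.mapping table as the patch to resolve
each graph input's numpy dtype, falling back to float32 when the input's
type field is not populated:

    import numpy as np
    import onnx
    from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE

    # load the exported UNet graph (placeholder path, adjust as needed)
    model = onnx.load("unet/model.onnx")

    # map each graph input name to its numpy dtype; ListFields() only
    # yields fields that are set, so next() with a default covers inputs
    # whose type is missing from the graph
    input_types = {
        graph_input.name: next(
            (
                TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
                for field in graph_input.type.ListFields()
            ),
            np.float32,
        )
        for graph_input in model.graph.input
    }

    print(input_types)

Caching these types once in __init__ (with a lazy fallback in __call__)
avoids walking the graph on every denoising step, and replaces the earlier
XL and int-timestep special cases with a single data-driven conversion path
keyed on the wrapped model's declared input types.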