From e2c9389d6ee86dadee5dba3b3928c8db589b4793 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 24 Dec 2023 22:57:02 -0600 Subject: [PATCH] add fallback dtypes to unet patch --- api/onnx_web/diffusers/patches/unet.py | 42 ++++++++++++++++++++------ 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index eef1d641..f4609e83 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -15,7 +15,9 @@ class UNetWrapper(object): input_types: Optional[Dict[str, np.dtype]] = None prompt_embeds: Optional[List[np.ndarray]] = None prompt_index: int = 0 + sample_dtype: np.dtype server: ServerContext + timestep_dtype: np.dtype wrapped: Union[OnnxRuntimeModel, ORTModelUnet] xl: bool @@ -24,10 +26,14 @@ class UNetWrapper(object): server: ServerContext, wrapped: Union[OnnxRuntimeModel, ORTModelUnet], xl: bool, + sample_dtype: Optional[np.dtype] = None, + timestep_dtype: np.dtype = np.int64, ): self.server = server self.wrapped = wrapped self.xl = xl + self.sample_dtype = sample_dtype or server.torch_dtype + self.timestep_dtype = timestep_dtype self.cache_input_types() @@ -54,19 +60,36 @@ class UNetWrapper(object): if self.input_types is None: self.cache_input_types() - if encoder_hidden_states.dtype != self.input_types["encoder_hidden_states"]: - logger.trace("converting UNet hidden states to input dtype") + encoder_hidden_states_input_dtype = self.input_types.get( + "encoder_hidden_states", self.sample_dtype + ) + if encoder_hidden_states.dtype != encoder_hidden_states_input_dtype: + logger.debug( + "converting UNet hidden states to input dtype from %s to %s", + encoder_hidden_states.dtype, + encoder_hidden_states_input_dtype, + ) encoder_hidden_states = encoder_hidden_states.astype( - self.input_types["encoder_hidden_states"] + encoder_hidden_states_input_dtype ) - if sample.dtype != self.input_types["sample"]: - logger.trace("converting UNet sample to input dtype") - sample = sample.astype(self.input_types["sample"]) + sample_input_dtype = self.input_types.get("sample", self.sample_dtype) + if sample.dtype != sample_input_dtype: + logger.debug( + "converting UNet sample to input dtype from %s to %s", + sample.dtype, + sample_input_dtype, + ) + sample = sample.astype(sample_input_dtype) - if timestep.dtype != self.input_types["timestep"]: - logger.trace("converting UNet timestep to input dtype") - timestep = timestep.astype(self.input_types["timestep"]) + timestep_input_dtype = self.input_types.get("timestep", self.timestep_dtype) + if timestep.dtype != timestep_input_dtype: + logger.debug( + "converting UNet timestep to input dtype from %s to %s", + timestep.dtype, + timestep_input_dtype, + ) + timestep = timestep.astype(timestep_input_dtype) return self.wrapped( sample=sample, @@ -79,7 +102,6 @@ class UNetWrapper(object): return getattr(self.wrapped, attr) def cache_input_types(self): - # TODO: use server dtype as default if isinstance(self.wrapped, ORTModelUnet): session = self.wrapped.session elif isinstance(self.wrapped, OnnxRuntimeModel):