diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index 81065d97..b7c020de 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -47,13 +47,15 @@ class UNetWrapper(object): self.prompt_index += 1 if self.xl: + # for XL, the sample and hidden states should match if sample.dtype != encoder_hidden_states.dtype: logger.trace( "converting UNet sample to hidden state dtype for XL: %s", encoder_hidden_states.dtype, ) sample = sample.astype(encoder_hidden_states.dtype) - else: + elif timestep.dtype != np.int64: + # the optimum converter uses an int timestep if sample.dtype != timestep.dtype: logger.trace("converting UNet sample to timestep dtype") sample = sample.astype(timestep.dtype)