diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index cef594c1..3c0affc1 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Union import numpy as np from diffusers import OnnxRuntimeModel +from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE from optimum.onnxruntime.modeling_diffusion import ORTModelUnet from ...server import ServerContext @@ -87,7 +88,7 @@ class UNetWrapper(object): raise ValueError() inputs = session.get_inputs() - self.input_types = dict([(input.name, input.type) for input in inputs]) + self.input_types = dict([(input.name, TENSOR_TYPE_TO_NP_TYPE[input.type]) for input in inputs]) logger.debug("cached UNet input types: %s", self.input_types) # [