From ef256280b48c5c47068ecb481387455ff9e00b3c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 24 Dec 2023 22:46:22 -0600 Subject: [PATCH] use right type mapping --- api/onnx_web/diffusers/patches/unet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index 3c0affc1..eef1d641 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Union import numpy as np from diffusers import OnnxRuntimeModel -from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE +from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE from optimum.onnxruntime.modeling_diffusion import ORTModelUnet from ...server import ServerContext @@ -88,7 +88,9 @@ class UNetWrapper(object): raise ValueError() inputs = session.get_inputs() - self.input_types = dict([(input.name, TENSOR_TYPE_TO_NP_TYPE[input.type]) for input in inputs]) + self.input_types = dict( + [(input.name, ORT_TO_NP_TYPE[input.type]) for input in inputs] + ) logger.debug("cached UNet input types: %s", self.input_types) # [