From 0f6a1a82a24cee924f2a5e4afe3588b156fa825c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 24 Dec 2023 22:40:07 -0600 Subject: [PATCH] translate types to np --- api/onnx_web/diffusers/patches/unet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index cef594c1..3c0affc1 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Union import numpy as np from diffusers import OnnxRuntimeModel +from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE from optimum.onnxruntime.modeling_diffusion import ORTModelUnet from ...server import ServerContext @@ -87,7 +88,7 @@ class UNetWrapper(object): raise ValueError() inputs = session.get_inputs() - self.input_types = dict([(input.name, input.type) for input in inputs]) + self.input_types = dict([(input.name, TENSOR_TYPE_TO_NP_TYPE[input.type]) for input in inputs]) logger.debug("cached UNet input types: %s", self.input_types) # [