1
0
Fork 0

use right type mapping

This commit is contained in:
Sean Sube 2023-12-24 22:46:22 -06:00
parent 0f6a1a82a2
commit ef256280b4
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 4 additions and 2 deletions

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Union
import numpy as np
from diffusers import OnnxRuntimeModel
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
from ...server import ServerContext
@ -88,7 +88,9 @@ class UNetWrapper(object):
raise ValueError()
inputs = session.get_inputs()
self.input_types = dict([(input.name, TENSOR_TYPE_TO_NP_TYPE[input.type]) for input in inputs])
self.input_types = dict(
[(input.name, ORT_TO_NP_TYPE[input.type]) for input in inputs]
)
logger.debug("cached UNet input types: %s", self.input_types)
# [