use right type mapping
This commit is contained in:
parent
0f6a1a82a2
commit
ef256280b4
|
@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Union
|
|||
|
||||
import numpy as np
|
||||
from diffusers import OnnxRuntimeModel
|
||||
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
|
||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
|
||||
from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
|
||||
|
||||
from ...server import ServerContext
|
||||
|
@ -88,7 +88,9 @@ class UNetWrapper(object):
|
|||
raise ValueError()
|
||||
|
||||
inputs = session.get_inputs()
|
||||
self.input_types = dict([(input.name, TENSOR_TYPE_TO_NP_TYPE[input.type]) for input in inputs])
|
||||
self.input_types = dict(
|
||||
[(input.name, ORT_TO_NP_TYPE[input.type]) for input in inputs]
|
||||
)
|
||||
logger.debug("cached UNet input types: %s", self.input_types)
|
||||
|
||||
# [
|
||||
|
|
Loading…
Reference in New Issue