translate types to np
This commit is contained in:
parent
39ee4cbfcd
commit
0f6a1a82a2
|
@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers import OnnxRuntimeModel
|
from diffusers import OnnxRuntimeModel
|
||||||
|
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
|
||||||
from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
|
from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
|
||||||
|
|
||||||
from ...server import ServerContext
|
from ...server import ServerContext
|
||||||
|
@ -87,7 +88,7 @@ class UNetWrapper(object):
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
|
||||||
inputs = session.get_inputs()
|
inputs = session.get_inputs()
|
||||||
self.input_types = dict([(input.name, input.type) for input in inputs])
|
self.input_types = dict([(input.name, TENSOR_TYPE_TO_NP_TYPE[input.type]) for input in inputs])
|
||||||
logger.debug("cached UNet input types: %s", self.input_types)
|
logger.debug("cached UNet input types: %s", self.input_types)
|
||||||
|
|
||||||
# [
|
# [
|
||||||
|
|
Loading…
Reference in New Issue