1
0
Fork 0

translate types to np

This commit is contained in:
Sean Sube 2023-12-24 22:40:07 -06:00
parent 39ee4cbfcd
commit 0f6a1a82a2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 2 additions and 1 deletions

View File

@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
from diffusers import OnnxRuntimeModel from diffusers import OnnxRuntimeModel
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from optimum.onnxruntime.modeling_diffusion import ORTModelUnet from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
from ...server import ServerContext from ...server import ServerContext
@ -87,7 +88,7 @@ class UNetWrapper(object):
raise ValueError() raise ValueError()
inputs = session.get_inputs() inputs = session.get_inputs()
self.input_types = dict([(input.name, input.type) for input in inputs]) self.input_types = dict([(input.name, TENSOR_TYPE_TO_NP_TYPE[input.type]) for input in inputs])
logger.debug("cached UNet input types: %s", self.input_types) logger.debug("cached UNet input types: %s", self.input_types)
# [ # [