add ORT type table
This commit is contained in:
parent
3d73b9e621
commit
f5ae9dd492
|
@ -23,6 +23,21 @@ logger = getLogger(__name__)
|
||||||
NUM_LATENT_CHANNELS = 4
|
NUM_LATENT_CHANNELS = 4
|
||||||
NUM_UNET_INPUT_CHANNELS = 7
|
NUM_UNET_INPUT_CHANNELS = 7
|
||||||
|
|
||||||
|
ORT_TO_NP_TYPE = {
|
||||||
|
"tensor(bool)": np.bool_,
|
||||||
|
"tensor(int8)": np.int8,
|
||||||
|
"tensor(uint8)": np.uint8,
|
||||||
|
"tensor(int16)": np.int16,
|
||||||
|
"tensor(uint16)": np.uint16,
|
||||||
|
"tensor(int32)": np.int32,
|
||||||
|
"tensor(uint32)": np.uint32,
|
||||||
|
"tensor(int64)": np.int64,
|
||||||
|
"tensor(uint64)": np.uint64,
|
||||||
|
"tensor(float16)": np.float16,
|
||||||
|
"tensor(float)": np.float32,
|
||||||
|
"tensor(double)": np.float64,
|
||||||
|
}
|
||||||
|
|
||||||
TORCH_DTYPES = {
|
TORCH_DTYPES = {
|
||||||
"float16": torch.float16,
|
"float16": torch.float16,
|
||||||
"float32": torch.float32,
|
"float32": torch.float32,
|
||||||
|
|
Loading…
Reference in New Issue