onnx-web/api/onnx_web/onnx/onnx_net.py


from os import path
from typing import Any, Optional

import numpy as np
import torch

from ..server import ServerContext
from ..torch_before_ort import InferenceSession, SessionOptions


class OnnxTensor:
    """
    Wrap a numpy array in the minimal subset of the torch.Tensor interface
    needed by callers that expect torch tensors.
    """

    def __init__(self, source) -> None:
        self.source = source
        self.data = self

    def __getitem__(self, *args):
        return torch.from_numpy(self.source.__getitem__(*args)).to(torch.float32)

    def squeeze(self):
        self.source = np.squeeze(self.source, (0))
        return self

    def float(self):
        return self

    def cpu(self):
        return self

    def clamp_(self, min, max):
        self.source = np.clip(self.source, min, max)
        return self

    def numpy(self):
        return self.source

    def size(self):
        return np.shape(self.source)
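

# Illustrative example: OnnxTensor lets numpy output from an ONNX session be
# consumed by code written against the torch.Tensor API. The shapes below are
# arbitrary.
#
#     t = OnnxTensor(np.zeros((1, 3, 8, 8), dtype=np.float32))
#     t.clamp_(0.0, 1.0).squeeze().size()   # -> (3, 8, 8)
#     t[0]                                  # -> torch.Tensor, dtype float32

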
class OnnxRRDBNet:
    """
    Provides the RRDBNet interface using an ONNX session.
    """

    def __init__(
        self,
        server: ServerContext,
        model: str,
        provider: str = "DmlExecutionProvider",
        sess_options: Optional[SessionOptions] = None,
    ) -> None:
        model_path = path.join(server.model_path, model)
        self.session = InferenceSession(
            model_path, providers=[provider], sess_options=sess_options
        )
    def __call__(self, image: Any) -> Any:
        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
        return OnnxTensor(output)

    def eval(self) -> None:
        # no-op: an ONNX session is always in inference mode
        pass

    def half(self):
        return self

    def load_state_dict(self, _net, strict=True) -> None:
        # no-op: weights are loaded from the ONNX model file instead
        pass

    def to(self, _device):
        return self
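

# Usage sketch (illustrative). Assumes an existing ServerContext `server` and
# a hypothetical model file name; the input can be anything exposing
# .cpu().numpy(), such as a torch tensor in NCHW layout.
#
#     net = OnnxRRDBNet(server, "rrdbnet.onnx", provider="CPUExecutionProvider")
#     image = torch.rand(1, 3, 64, 64)                  # float32 NCHW input
#     output = net(image)                               # runs the session, returns OnnxTensor
#     result = output.squeeze().float().cpu().numpy()   # numpy array, batch dim removed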