1
0
Fork 0
onnx-web/api/onnx_web/models/onnx.py

29 lines
868 B
Python
Raw Normal View History

from logging import getLogger
from os import path
from typing import Any, Optional
from ..server import ServerContext
from ..torch_before_ort import InferenceSession, SessionOptions
logger = getLogger(__name__)


class OnnxModel:
    """Run a single ONNX model through an inference session.

    The model file is resolved relative to ``server.model_path``. Calling the
    instance feeds its argument to the model's first input and returns the
    model's first output; any further inputs or outputs are not used.
    """

    def __init__(
        self,
        server: ServerContext,
        model: str,
        provider: str = "DmlExecutionProvider",
        sess_options: Optional[SessionOptions] = None,
    ) -> None:
        """Load the model file and create its inference session.

        Args:
            server: server context; only its ``model_path`` is read here.
            model: model file name, joined onto ``server.model_path``.
            provider: name of the single execution provider the session uses.
            sess_options: optional session options forwarded to the session.
        """
        model_path = path.join(server.model_path, model)
        logger.debug("loading ONNX model from %s using %s", model_path, provider)

        # Session options belong in ``sess_options``. ``provider_options`` is a
        # sequence of per-provider option dicts matching ``providers``, so
        # passing a SessionOptions object there misroutes it.
        # NOTE(review): assumes ``..torch_before_ort`` re-exports ONNX
        # Runtime's InferenceSession/SessionOptions — confirm.
        self.session = InferenceSession(
            model_path, providers=[provider], sess_options=sess_options
        )

    def __call__(self, image: Any) -> Any:
        """Run the model on ``image`` and return its first output.

        ``image`` is passed unchanged as the value of the model's first input;
        presumably an array matching that input's shape and dtype — TODO
        confirm against callers.
        """
        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        # run() returns a list with one entry per requested output name.
        output = self.session.run([output_name], {input_name: image})[0]
        return output