From 8a2a9174ba01095850c53875b1082cfc87755da9 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Tue, 14 Feb 2023 18:57:50 -0600
Subject: [PATCH] fix(api): pass both device and session options to ORT (#38)

---
 api/onnx_web/chain/upscale_resrgan.py          |  4 ++--
 api/onnx_web/chain/upscale_stable_diffusion.py |  7 +++++--
 api/onnx_web/diffusion/load.py                 | 12 ++++++------
 api/onnx_web/onnx/onnx_net.py                  |  4 ++--
 api/onnx_web/params.py                         | 10 ++++++++++
 5 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py
index 41d7bfc9..deff5894 100644
--- a/api/onnx_web/chain/upscale_resrgan.py
+++ b/api/onnx_web/chain/upscale_resrgan.py
@@ -52,8 +52,8 @@ def load_resrgan(
         model = OnnxNet(
             server,
             model_file,
-            provider=device.provider,
-            sess_options=device.options,
+            provider=device.ort_provider(),
+            sess_options=device.sess_options(),
         )
     elif params.format == "pth":
         model = RRDBNet(
diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py
index 67c93e4d..0de51883 100644
--- a/api/onnx_web/chain/upscale_stable_diffusion.py
+++ b/api/onnx_web/chain/upscale_stable_diffusion.py
@@ -34,7 +34,9 @@ def load_stable_diffusion(
             device.provider,
         )
         pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
-            model_path, provider=device.provider, sess_options=device.options
+            model_path,
+            provider=device.ort_provider(),
+            sess_options=device.sess_options(),
         )
     else:
         logger.debug(
@@ -43,7 +45,8 @@ def load_stable_diffusion(
             device.provider,
         )
         pipe = StableDiffusionUpscalePipeline.from_pretrained(
-            model_path, provider=device.provider
+            model_path,
+            provider=device.provider,
         )

     server.cache.set("diffusion", cache_key, pipe)
diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py
index 0e3ad495..82599299 100644
--- a/api/onnx_web/diffusion/load.py
+++ b/api/onnx_web/diffusion/load.py
@@ -109,8 +109,8 @@ def load_pipeline(
             logger.debug("loading new diffusion scheduler")
             scheduler = scheduler_type.from_pretrained(
                 model,
-                provider=device.provider,
-                sess_options=device.options,
+                provider=device.ort_provider(),
+                sess_options=device.sess_options(),
                 subfolder="scheduler",
             )

@@ -134,15 +134,15 @@ def load_pipeline(
         logger.debug("loading new diffusion pipeline from %s", model)
         scheduler = scheduler_type.from_pretrained(
             model,
-            provider=device.provider,
-            sess_options=device.options,
+            provider=device.ort_provider(),
+            sess_options=device.sess_options(),
             subfolder="scheduler",
         )
         pipe = pipeline.from_pretrained(
             model,
             custom_pipeline=custom_pipeline,
-            provider=device.provider,
-            sess_options=device.options,
+            provider=device.ort_provider(),
+            sess_options=device.sess_options(),
             revision="onnx",
             safety_checker=None,
             scheduler=scheduler,
diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py
index 950c0ce7..51697380 100644
--- a/api/onnx_web/onnx/onnx_net.py
+++ b/api/onnx_web/onnx/onnx_net.py
@@ -3,7 +3,7 @@ from typing import Any, Optional

 import numpy as np
 import torch
-from onnxruntime import InferenceSession
+from onnxruntime import InferenceSession, SessionOptions

 from ..utils import ServerContext

@@ -47,7 +47,7 @@ class OnnxNet:
         server: ServerContext,
         model: str,
         provider: str = "DmlExecutionProvider",
-        sess_options: Optional[dict] = None,
+        sess_options: Optional[SessionOptions] = None,
     ) -> None:
         model_path = path.join(server.model_path, model)
         self.session = InferenceSession(
diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py
index a6144c24..39990d89 100644
--- a/api/onnx_web/params.py
+++ b/api/onnx_web/params.py
@@ -1,4 +1,5 @@
 from enum import IntEnum
+from onnxruntime import SessionOptions
 from typing import Any, Dict, Literal, Optional, Tuple, Union


@@ -79,6 +80,15 @@ class DeviceParams:
     def __str__(self) -> str:
         return "%s - %s (%s)" % (self.device, self.provider, self.options)

+    def ort_provider(self) -> Tuple[str, Any]:
+        if self.options is None:
+            return self.provider
+        else:
+            return (self.provider, self.options)
+
+    def sess_options(self) -> SessionOptions:
+        return SessionOptions()
+
     def torch_device(self) -> str:
         if self.device.startswith("cuda"):
             return self.device
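
A minimal usage sketch of the pattern this patch relies on: ONNX Runtime accepts a provider either as a bare name or as a (name, options) tuple, which is the shape DeviceParams.ort_provider() now returns, alongside a separate SessionOptions object. The model path and device_id below are hypothetical, chosen only to illustrate the call shape.

# Sketch only, not part of the patch above: pass a (provider, options) tuple and a
# SessionOptions instance to ONNX Runtime. "model.onnx" and device_id are hypothetical.
from onnxruntime import InferenceSession, SessionOptions

sess_options = SessionOptions()
provider = ("CUDAExecutionProvider", {"device_id": 0})

session = InferenceSession(
    "model.onnx",
    providers=[provider],
    sess_options=sess_options,
)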