From d473a0fd2db1fc5bfca207b8e5c248673c5e7338 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 14 Feb 2023 17:12:52 -0600 Subject: [PATCH] fix(api): pass device options to ORT session (#38) --- api/onnx_web/chain/upscale_resrgan.py | 2 +- api/onnx_web/chain/upscale_stable_diffusion.py | 2 +- api/onnx_web/diffusion/load.py | 6 +++--- api/onnx_web/onnx/onnx_net.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index f311945e..41d7bfc9 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -53,7 +53,7 @@ def load_resrgan( server, model_file, provider=device.provider, - provider_options=device.options, + sess_options=device.options, ) elif params.format == "pth": model = RRDBNet( diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 8b540b68..67c93e4d 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -34,7 +34,7 @@ def load_stable_diffusion( device.provider, ) pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained( - model_path, provider=device.provider, provider_options=device.options + model_path, provider=device.provider, sess_options=device.options ) else: logger.debug( diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 8a1f25e3..0e3ad495 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -110,7 +110,7 @@ def load_pipeline( scheduler = scheduler_type.from_pretrained( model, provider=device.provider, - provider_options=device.options, + sess_options=device.options, subfolder="scheduler", ) @@ -135,14 +135,14 @@ def load_pipeline( scheduler = scheduler_type.from_pretrained( model, provider=device.provider, - provider_options=device.options, + sess_options=device.options, subfolder="scheduler", ) pipe = pipeline.from_pretrained( model, custom_pipeline=custom_pipeline, provider=device.provider, - provider_options=device.options, + sess_options=device.options, revision="onnx", safety_checker=None, scheduler=scheduler, diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 9e6c1d7e..950c0ce7 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -47,11 +47,11 @@ class OnnxNet: server: ServerContext, model: str, provider: str = "DmlExecutionProvider", - provider_options: Optional[dict] = None, + sess_options: Optional[dict] = None, ) -> None: model_path = path.join(server.model_path, model) self.session = InferenceSession( - model_path, providers=[provider], provider_options=provider_options + model_path, providers=[provider], provider_options=sess_options ) def __call__(self, image: Any) -> Any: