From 9f9b73b780d926dff240741ee3ef90c72d4b10a7 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Mar 2023 13:39:04 -0500 Subject: [PATCH] fix(api): use ORT session for correct device when loading blended nets --- api/onnx_web/diffusers/load.py | 5 ++--- api/onnx_web/params.py | 8 +++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 857874e6..f0ffe873 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -22,7 +22,6 @@ from diffusers import ( StableDiffusionPipeline, ) from onnx import load_model -from onnxruntime import SessionOptions from transformers import CLIPTokenizer from onnx_web.diffusers.utils import expand_prompt @@ -271,7 +270,7 @@ def load_pipeline( text_encoder ) text_encoder_names, text_encoder_values = zip(*text_encoder_data) - text_encoder_opts = SessionOptions() + text_encoder_opts = device.sess_options(cache=False) text_encoder_opts.add_external_initializers( list(text_encoder_names), list(text_encoder_values) ) @@ -292,7 +291,7 @@ def load_pipeline( ) (unet_model, unet_data) = buffer_external_data_tensors(blended_unet) unet_names, unet_values = zip(*unet_data) - unet_opts = SessionOptions() + unet_opts = device.sess_options(cache=False) unet_opts.add_external_initializers(list(unet_names), list(unet_values)) components["unet"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 8d238046..431745fd 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -113,8 +113,8 @@ class DeviceParams: else: return (self.provider, self.options) - def sess_options(self) -> SessionOptions: - if self.sess_options_cache is not None: + def sess_options(self, cache = True) -> SessionOptions: + if cache and self.sess_options_cache is not None: return self.sess_options_cache sess = SessionOptions() @@ -139,7 +139,9 @@ class DeviceParams: logger.debug("enabling ONNX deterministic compute") sess.use_deterministic_compute = True - self.sess_options_cache = sess + if cache: + self.sess_options_cache = sess + return sess def torch_str(self) -> str: