1
0
Fork 0

fix(api): use ORT session for correct device when loading blended nets

This commit is contained in:
Sean Sube 2023-03-18 13:39:04 -05:00
parent c465b61fb5
commit 9f9b73b780
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 7 additions and 6 deletions

View File

@ -22,7 +22,6 @@ from diffusers import (
StableDiffusionPipeline, StableDiffusionPipeline,
) )
from onnx import load_model from onnx import load_model
from onnxruntime import SessionOptions
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from onnx_web.diffusers.utils import expand_prompt from onnx_web.diffusers.utils import expand_prompt
@ -271,7 +270,7 @@ def load_pipeline(
text_encoder text_encoder
) )
text_encoder_names, text_encoder_values = zip(*text_encoder_data) text_encoder_names, text_encoder_values = zip(*text_encoder_data)
text_encoder_opts = SessionOptions() text_encoder_opts = device.sess_options(cache=False)
text_encoder_opts.add_external_initializers( text_encoder_opts.add_external_initializers(
list(text_encoder_names), list(text_encoder_values) list(text_encoder_names), list(text_encoder_values)
) )
@ -292,7 +291,7 @@ def load_pipeline(
) )
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet) (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
unet_names, unet_values = zip(*unet_data) unet_names, unet_values = zip(*unet_data)
unet_opts = SessionOptions() unet_opts = device.sess_options(cache=False)
unet_opts.add_external_initializers(list(unet_names), list(unet_values)) unet_opts.add_external_initializers(list(unet_names), list(unet_values))
components["unet"] = OnnxRuntimeModel( components["unet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model( OnnxRuntimeModel.load_model(

View File

@ -113,8 +113,8 @@ class DeviceParams:
else: else:
return (self.provider, self.options) return (self.provider, self.options)
def sess_options(self) -> SessionOptions: def sess_options(self, cache = True) -> SessionOptions:
if self.sess_options_cache is not None: if cache and self.sess_options_cache is not None:
return self.sess_options_cache return self.sess_options_cache
sess = SessionOptions() sess = SessionOptions()
@ -139,7 +139,9 @@ class DeviceParams:
logger.debug("enabling ONNX deterministic compute") logger.debug("enabling ONNX deterministic compute")
sess.use_deterministic_compute = True sess.use_deterministic_compute = True
self.sess_options_cache = sess if cache:
self.sess_options_cache = sess
return sess return sess
def torch_str(self) -> str: def torch_str(self) -> str: