fix(api): use ORT session for correct device when loading blended nets
This commit is contained in:
parent
c465b61fb5
commit
9f9b73b780
|
@ -22,7 +22,6 @@ from diffusers import (
|
|||
StableDiffusionPipeline,
|
||||
)
|
||||
from onnx import load_model
|
||||
from onnxruntime import SessionOptions
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from onnx_web.diffusers.utils import expand_prompt
|
||||
|
@ -271,7 +270,7 @@ def load_pipeline(
|
|||
text_encoder
|
||||
)
|
||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||
text_encoder_opts = SessionOptions()
|
||||
text_encoder_opts = device.sess_options(cache=False)
|
||||
text_encoder_opts.add_external_initializers(
|
||||
list(text_encoder_names), list(text_encoder_values)
|
||||
)
|
||||
|
@ -292,7 +291,7 @@ def load_pipeline(
|
|||
)
|
||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||
unet_names, unet_values = zip(*unet_data)
|
||||
unet_opts = SessionOptions()
|
||||
unet_opts = device.sess_options(cache=False)
|
||||
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
||||
components["unet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
|
|
|
@ -113,8 +113,8 @@ class DeviceParams:
|
|||
else:
|
||||
return (self.provider, self.options)
|
||||
|
||||
def sess_options(self) -> SessionOptions:
|
||||
if self.sess_options_cache is not None:
|
||||
def sess_options(self, cache = True) -> SessionOptions:
|
||||
if cache and self.sess_options_cache is not None:
|
||||
return self.sess_options_cache
|
||||
|
||||
sess = SessionOptions()
|
||||
|
@ -139,7 +139,9 @@ class DeviceParams:
|
|||
logger.debug("enabling ONNX deterministic compute")
|
||||
sess.use_deterministic_compute = True
|
||||
|
||||
if cache:
|
||||
self.sess_options_cache = sess
|
||||
|
||||
return sess
|
||||
|
||||
def torch_str(self) -> str:
|
||||
|
|
Loading…
Reference in New Issue