fix(api): use ORT session for correct device when loading blended nets
This commit is contained in:
parent
c465b61fb5
commit
9f9b73b780
|
@ -22,7 +22,6 @@ from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
from onnx import load_model
|
from onnx import load_model
|
||||||
from onnxruntime import SessionOptions
|
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from onnx_web.diffusers.utils import expand_prompt
|
from onnx_web.diffusers.utils import expand_prompt
|
||||||
|
@ -271,7 +270,7 @@ def load_pipeline(
|
||||||
text_encoder
|
text_encoder
|
||||||
)
|
)
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
text_encoder_opts = SessionOptions()
|
text_encoder_opts = device.sess_options(cache=False)
|
||||||
text_encoder_opts.add_external_initializers(
|
text_encoder_opts.add_external_initializers(
|
||||||
list(text_encoder_names), list(text_encoder_values)
|
list(text_encoder_names), list(text_encoder_values)
|
||||||
)
|
)
|
||||||
|
@ -292,7 +291,7 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
unet_opts = SessionOptions()
|
unet_opts = device.sess_options(cache=False)
|
||||||
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
||||||
components["unet"] = OnnxRuntimeModel(
|
components["unet"] = OnnxRuntimeModel(
|
||||||
OnnxRuntimeModel.load_model(
|
OnnxRuntimeModel.load_model(
|
||||||
|
|
|
@ -113,8 +113,8 @@ class DeviceParams:
|
||||||
else:
|
else:
|
||||||
return (self.provider, self.options)
|
return (self.provider, self.options)
|
||||||
|
|
||||||
def sess_options(self) -> SessionOptions:
|
def sess_options(self, cache = True) -> SessionOptions:
|
||||||
if self.sess_options_cache is not None:
|
if cache and self.sess_options_cache is not None:
|
||||||
return self.sess_options_cache
|
return self.sess_options_cache
|
||||||
|
|
||||||
sess = SessionOptions()
|
sess = SessionOptions()
|
||||||
|
@ -139,7 +139,9 @@ class DeviceParams:
|
||||||
logger.debug("enabling ONNX deterministic compute")
|
logger.debug("enabling ONNX deterministic compute")
|
||||||
sess.use_deterministic_compute = True
|
sess.use_deterministic_compute = True
|
||||||
|
|
||||||
|
if cache:
|
||||||
self.sess_options_cache = sess
|
self.sess_options_cache = sess
|
||||||
|
|
||||||
return sess
|
return sess
|
||||||
|
|
||||||
def torch_str(self) -> str:
|
def torch_str(self) -> str:
|
||||||
|
|
Loading…
Reference in New Issue