fix(api): ensure VAE is loaded on correct device
This commit is contained in:
parent
fd8b9bef3b
commit
d4b013068d
|
@ -7,7 +7,12 @@ from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
|
|||
ORTStableDiffusionXLImg2ImgPipeline,
|
||||
ORTStableDiffusionXLPipeline,
|
||||
)
|
||||
from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet
|
||||
from optimum.onnxruntime.modeling_diffusion import (
|
||||
ORTModelTextEncoder,
|
||||
ORTModelUnet,
|
||||
ORTModelVaeDecoder,
|
||||
ORTModelVaeEncoder,
|
||||
)
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ..constants import ONNX_MODEL
|
||||
|
@ -363,26 +368,40 @@ def load_pipeline(
|
|||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
elif (
|
||||
not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
|
||||
):
|
||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||
components["vae_decoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
elif path.exists(vae_decoder) and path.exists(vae_encoder):
|
||||
if params.is_xl():
|
||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
|
||||
vae_decoder,
|
||||
provider=device.ort_provider("vae"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||
components["vae_encoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||
vae_encoder,
|
||||
provider=device.ort_provider("vae"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||
components["vae_decoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
vae_decoder,
|
||||
provider=device.ort_provider("vae"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||
components["vae_encoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
vae_encoder,
|
||||
provider=device.ort_provider("vae"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
|
||||
# additional options for panorama pipeline
|
||||
if params.is_panorama():
|
||||
|
@ -402,33 +421,30 @@ def load_pipeline(
|
|||
|
||||
# make sure XL models are actually being used
|
||||
if "text_encoder_session" in components:
|
||||
logger.info(
|
||||
"text encoder matches: %s, %s",
|
||||
pipe.text_encoder.session == components["text_encoder_session"],
|
||||
type(pipe.text_encoder),
|
||||
)
|
||||
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
|
||||
|
||||
if "text_encoder_2_session" in components:
|
||||
logger.info(
|
||||
"text encoder 2 matches: %s, %s",
|
||||
pipe.text_encoder_2.session == components["text_encoder_2_session"],
|
||||
type(pipe.text_encoder_2),
|
||||
)
|
||||
pipe.text_encoder_2 = ORTModelTextEncoder(
|
||||
text_encoder_2_session, text_encoder_2
|
||||
)
|
||||
|
||||
if "unet_session" in components:
|
||||
logger.info(
|
||||
"unet matches: %s, %s",
|
||||
pipe.unet.session == components["unet_session"],
|
||||
type(pipe.unet),
|
||||
)
|
||||
# unload old UNet first
|
||||
pipe.unet = None
|
||||
run_gc([device])
|
||||
# load correct one
|
||||
pipe.unet = ORTModelUnet(unet_session, unet_model)
|
||||
|
||||
if "vae_decoder_session" in components:
|
||||
pipe.vae_decoder = ORTModelVaeDecoder(
|
||||
components["vae_decoder_session"], vae_decoder
|
||||
)
|
||||
|
||||
if "vae_encoder_session" in components:
|
||||
pipe.vae_encoder = ORTModelVaeEncoder(
|
||||
components["vae_encoder_session"], vae_encoder
|
||||
)
|
||||
|
||||
if not server.show_progress:
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue