fix(api): ensure VAE is loaded on correct device
This commit is contained in:
parent
fd8b9bef3b
commit
d4b013068d
|
@@ -7,7 +7,12 @@ from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
|
||||||
ORTStableDiffusionXLImg2ImgPipeline,
|
ORTStableDiffusionXLImg2ImgPipeline,
|
||||||
ORTStableDiffusionXLPipeline,
|
ORTStableDiffusionXLPipeline,
|
||||||
)
|
)
|
||||||
from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet
|
from optimum.onnxruntime.modeling_diffusion import (
|
||||||
|
ORTModelTextEncoder,
|
||||||
|
ORTModelUnet,
|
||||||
|
ORTModelVaeDecoder,
|
||||||
|
ORTModelVaeEncoder,
|
||||||
|
)
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from ..constants import ONNX_MODEL
|
from ..constants import ONNX_MODEL
|
||||||
|
@@ -363,26 +368,40 @@ def load_pipeline(
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif (
|
elif path.exists(vae_decoder) and path.exists(vae_encoder):
|
||||||
not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
|
if params.is_xl():
|
||||||
):
|
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
|
||||||
components["vae_decoder"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
vae_decoder,
|
vae_decoder,
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||||
components["vae_encoder"] = OnnxRuntimeModel(
|
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
vae_encoder,
|
vae_encoder,
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
else:
|
||||||
|
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||||
|
components["vae_decoder"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
vae_decoder,
|
||||||
|
provider=device.ort_provider("vae"),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||||
|
components["vae_encoder"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
vae_encoder,
|
||||||
|
provider=device.ort_provider("vae"),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# additional options for panorama pipeline
|
# additional options for panorama pipeline
|
||||||
if params.is_panorama():
|
if params.is_panorama():
|
||||||
|
@@ -402,33 +421,30 @@ def load_pipeline(
|
||||||
|
|
||||||
# make sure XL models are actually being used
|
# make sure XL models are actually being used
|
||||||
if "text_encoder_session" in components:
|
if "text_encoder_session" in components:
|
||||||
logger.info(
|
|
||||||
"text encoder matches: %s, %s",
|
|
||||||
pipe.text_encoder.session == components["text_encoder_session"],
|
|
||||||
type(pipe.text_encoder),
|
|
||||||
)
|
|
||||||
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
|
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
|
||||||
|
|
||||||
if "text_encoder_2_session" in components:
|
if "text_encoder_2_session" in components:
|
||||||
logger.info(
|
|
||||||
"text encoder 2 matches: %s, %s",
|
|
||||||
pipe.text_encoder_2.session == components["text_encoder_2_session"],
|
|
||||||
type(pipe.text_encoder_2),
|
|
||||||
)
|
|
||||||
pipe.text_encoder_2 = ORTModelTextEncoder(
|
pipe.text_encoder_2 = ORTModelTextEncoder(
|
||||||
text_encoder_2_session, text_encoder_2
|
text_encoder_2_session, text_encoder_2
|
||||||
)
|
)
|
||||||
|
|
||||||
if "unet_session" in components:
|
if "unet_session" in components:
|
||||||
logger.info(
|
# unload old UNet first
|
||||||
"unet matches: %s, %s",
|
|
||||||
pipe.unet.session == components["unet_session"],
|
|
||||||
type(pipe.unet),
|
|
||||||
)
|
|
||||||
pipe.unet = None
|
pipe.unet = None
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
# load correct one
|
||||||
pipe.unet = ORTModelUnet(unet_session, unet_model)
|
pipe.unet = ORTModelUnet(unet_session, unet_model)
|
||||||
|
|
||||||
|
if "vae_decoder_session" in components:
|
||||||
|
pipe.vae_decoder = ORTModelVaeDecoder(
|
||||||
|
components["vae_decoder_session"], vae_decoder
|
||||||
|
)
|
||||||
|
|
||||||
|
if "vae_encoder_session" in components:
|
||||||
|
pipe.vae_encoder = ORTModelVaeEncoder(
|
||||||
|
components["vae_encoder_session"], vae_encoder
|
||||||
|
)
|
||||||
|
|
||||||
if not server.show_progress:
|
if not server.show_progress:
|
||||||
pipe.set_progress_bar_config(disable=True)
|
pipe.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue