1
0
Fork 0

fix(api): ensure VAE is loaded on correct device

This commit is contained in:
Sean Sube 2023-09-12 07:21:35 -05:00
parent fd8b9bef3b
commit d4b013068d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 43 additions and 27 deletions

View File

@@ -7,7 +7,12 @@ from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLPipeline,
)
from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet
from optimum.onnxruntime.modeling_diffusion import (
ORTModelTextEncoder,
ORTModelUnet,
ORTModelVaeDecoder,
ORTModelVaeEncoder,
)
from transformers import CLIPTokenizer
from ..constants import ONNX_MODEL
@@ -363,26 +368,40 @@ def load_pipeline(
sess_options=device.sess_options(),
)
)
elif (
not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
):
logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
elif path.exists(vae_decoder) and path.exists(vae_encoder):
if params.is_xl():
logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
vae_decoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
)
logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
vae_encoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
)
else:
logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
vae_decoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
)
logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
vae_encoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
)
# additional options for panorama pipeline
if params.is_panorama():
@@ -402,33 +421,30 @@ def load_pipeline(
# make sure XL models are actually being used
if "text_encoder_session" in components:
logger.info(
"text encoder matches: %s, %s",
pipe.text_encoder.session == components["text_encoder_session"],
type(pipe.text_encoder),
)
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
if "text_encoder_2_session" in components:
logger.info(
"text encoder 2 matches: %s, %s",
pipe.text_encoder_2.session == components["text_encoder_2_session"],
type(pipe.text_encoder_2),
)
pipe.text_encoder_2 = ORTModelTextEncoder(
text_encoder_2_session, text_encoder_2
)
if "unet_session" in components:
logger.info(
"unet matches: %s, %s",
pipe.unet.session == components["unet_session"],
type(pipe.unet),
)
# unload old UNet first
pipe.unet = None
run_gc([device])
# load correct one
pipe.unet = ORTModelUnet(unet_session, unet_model)
if "vae_decoder_session" in components:
pipe.vae_decoder = ORTModelVaeDecoder(
components["vae_decoder_session"], vae_decoder
)
if "vae_encoder_session" in components:
pipe.vae_encoder = ORTModelVaeEncoder(
components["vae_encoder_session"], vae_encoder
)
if not server.show_progress:
pipe.set_progress_bar_config(disable=True)