lint(api): start breaking down model loading
This commit is contained in:
parent
38d3999088
commit
6b6f63564e
|
@ -106,6 +106,9 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Pipeline attribute names that may hold a VAE model. load_pipeline iterates
# these and applies the tiling / window-size settings to each one present.
VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"]
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(
|
def load_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
|
@ -177,237 +180,28 @@ def load_pipeline(
|
||||||
}
|
}
|
||||||
|
|
||||||
# shared components
|
# shared components
|
||||||
text_encoder = None
|
|
||||||
unet_type = "unet"
|
unet_type = "unet"
|
||||||
|
|
||||||
# ControlNet component
|
# ControlNet component
|
||||||
if params.is_control() and params.control is not None:
|
if params.is_control() and params.control is not None:
|
||||||
cnet_path = path.join(
|
logger.debug("loading ControlNet components")
|
||||||
server.model_path, "control", f"{params.control.name}.onnx"
|
control_components = load_controlnet(server, device, params)
|
||||||
)
|
components.update(control_components)
|
||||||
logger.debug("loading ControlNet weights from %s", cnet_path)
|
|
||||||
components["controlnet"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
cnet_path,
|
|
||||||
provider=device.ort_provider(),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
unet_type = "cnet"
|
unet_type = "cnet"
|
||||||
|
|
||||||
# Textual Inversion blending
|
# Textual Inversion blending
|
||||||
if inversions is not None and len(inversions) > 0:
|
encoder_components = load_text_encoders(
|
||||||
logger.debug("blending Textual Inversions from %s", inversions)
|
server, device, model, inversions, loras, torch_dtype, params
|
||||||
inversion_names, inversion_weights = zip(*inversions)
|
)
|
||||||
|
components.update(encoder_components)
|
||||||
|
|
||||||
inversion_models = [
|
unet_components = load_unet(
|
||||||
path.join(server.model_path, "inversion", name)
|
server, device, model, loras, unet_type, params
|
||||||
for name in inversion_names
|
)
|
||||||
]
|
components.update(unet_components)
|
||||||
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
|
||||||
model,
|
|
||||||
subfolder="tokenizer",
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
text_encoder, tokenizer = blend_textual_inversions(
|
|
||||||
server,
|
|
||||||
text_encoder,
|
|
||||||
tokenizer,
|
|
||||||
list(
|
|
||||||
zip(
|
|
||||||
inversion_models,
|
|
||||||
inversion_weights,
|
|
||||||
inversion_names,
|
|
||||||
[None] * len(inversion_models),
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
components["tokenizer"] = tokenizer
|
vae_components = load_vae(server, device, model)
|
||||||
|
components.update(vae_components)
|
||||||
# should be pretty small and should not need external data
|
|
||||||
if loras is None or len(loras) == 0:
|
|
||||||
# TODO: handle XL encoders
|
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
text_encoder.SerializeToString(),
|
|
||||||
provider=device.ort_provider("text-encoder"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# LoRA blending
|
|
||||||
if loras is not None and len(loras) > 0:
|
|
||||||
lora_names, lora_weights = zip(*loras)
|
|
||||||
lora_models = [
|
|
||||||
path.join(server.model_path, "lora", name) for name in lora_names
|
|
||||||
]
|
|
||||||
logger.info(
|
|
||||||
"blending base model %s with LoRA models: %s", model, lora_models
|
|
||||||
)
|
|
||||||
|
|
||||||
# blend and load text encoder
|
|
||||||
text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
|
|
||||||
text_encoder = blend_loras(
|
|
||||||
server,
|
|
||||||
text_encoder,
|
|
||||||
list(zip(lora_models, lora_weights)),
|
|
||||||
"text_encoder",
|
|
||||||
1 if params.is_xl() else None,
|
|
||||||
params.is_xl(),
|
|
||||||
)
|
|
||||||
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
|
|
||||||
text_encoder
|
|
||||||
)
|
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
|
||||||
text_encoder_opts = device.sess_options(cache=False)
|
|
||||||
text_encoder_opts.add_external_initializers(
|
|
||||||
list(text_encoder_names), list(text_encoder_values)
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.is_xl():
|
|
||||||
text_encoder_session = InferenceSession(
|
|
||||||
text_encoder.SerializeToString(),
|
|
||||||
providers=[device.ort_provider("text-encoder")],
|
|
||||||
sess_options=text_encoder_opts,
|
|
||||||
)
|
|
||||||
text_encoder_session._model_path = path.join(model, "text_encoder")
|
|
||||||
components["text_encoder_session"] = text_encoder_session
|
|
||||||
else:
|
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
text_encoder.SerializeToString(),
|
|
||||||
provider=device.ort_provider("text-encoder"),
|
|
||||||
sess_options=text_encoder_opts,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.is_xl():
|
|
||||||
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
|
||||||
text_encoder_2 = blend_loras(
|
|
||||||
server,
|
|
||||||
text_encoder_2,
|
|
||||||
list(zip(lora_models, lora_weights)),
|
|
||||||
"text_encoder",
|
|
||||||
2,
|
|
||||||
params.is_xl(),
|
|
||||||
)
|
|
||||||
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
|
|
||||||
text_encoder_2
|
|
||||||
)
|
|
||||||
text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
|
|
||||||
text_encoder_2_opts = device.sess_options(cache=False)
|
|
||||||
text_encoder_2_opts.add_external_initializers(
|
|
||||||
list(text_encoder_2_names), list(text_encoder_2_values)
|
|
||||||
)
|
|
||||||
|
|
||||||
text_encoder_2_session = InferenceSession(
|
|
||||||
text_encoder_2.SerializeToString(),
|
|
||||||
providers=[device.ort_provider("text-encoder")],
|
|
||||||
sess_options=text_encoder_2_opts,
|
|
||||||
)
|
|
||||||
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
|
||||||
components["text_encoder_2_session"] = text_encoder_2_session
|
|
||||||
|
|
||||||
# blend and load unet
|
|
||||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
|
||||||
blended_unet = blend_loras(
|
|
||||||
server,
|
|
||||||
unet,
|
|
||||||
list(zip(lora_models, lora_weights)),
|
|
||||||
"unet",
|
|
||||||
xl=params.is_xl(),
|
|
||||||
)
|
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
|
||||||
unet_names, unet_values = zip(*unet_data)
|
|
||||||
unet_opts = device.sess_options(cache=False)
|
|
||||||
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
|
||||||
|
|
||||||
if params.is_xl():
|
|
||||||
unet_session = InferenceSession(
|
|
||||||
unet_model.SerializeToString(),
|
|
||||||
providers=[device.ort_provider("unet")],
|
|
||||||
sess_options=unet_opts,
|
|
||||||
)
|
|
||||||
unet_session._model_path = path.join(model, "unet")
|
|
||||||
components["unet_session"] = unet_session
|
|
||||||
else:
|
|
||||||
components["unet"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
unet_model.SerializeToString(),
|
|
||||||
provider=device.ort_provider("unet"),
|
|
||||||
sess_options=unet_opts,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# make sure a UNet has been loaded
|
|
||||||
if not params.is_xl() and "unet" not in components:
|
|
||||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
|
||||||
logger.debug("loading UNet (%s) from %s", unet_type, unet)
|
|
||||||
components["unet"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
unet,
|
|
||||||
provider=device.ort_provider("unet"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# one or more VAE models need to be loaded
|
|
||||||
vae = path.join(model, "vae", ONNX_MODEL)
|
|
||||||
vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
|
|
||||||
vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
|
|
||||||
|
|
||||||
if not params.is_xl() and path.exists(vae):
|
|
||||||
logger.debug("loading VAE from %s", vae)
|
|
||||||
components["vae"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
vae,
|
|
||||||
provider=device.ort_provider("vae"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif path.exists(vae_decoder) and path.exists(vae_encoder):
|
|
||||||
if params.is_xl():
|
|
||||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
|
||||||
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
|
|
||||||
vae_decoder,
|
|
||||||
provider=device.ort_provider("vae"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
components[
|
|
||||||
"vae_decoder_session"
|
|
||||||
]._model_path = vae_decoder # "#\\not a real path on any system"
|
|
||||||
|
|
||||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
|
||||||
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
|
||||||
vae_encoder,
|
|
||||||
provider=device.ort_provider("vae"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
components[
|
|
||||||
"vae_encoder_session"
|
|
||||||
]._model_path = vae_encoder # "#\\not a real path on any system"
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
|
||||||
components["vae_decoder"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
vae_decoder,
|
|
||||||
provider=device.ort_provider("vae"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
|
||||||
components["vae_encoder"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
vae_encoder,
|
|
||||||
provider=device.ort_provider("vae"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# additional options for panorama pipeline
|
# additional options for panorama pipeline
|
||||||
if params.is_panorama():
|
if params.is_panorama():
|
||||||
|
@ -427,19 +221,22 @@ def load_pipeline(
|
||||||
|
|
||||||
# make sure XL models are actually being used
|
# make sure XL models are actually being used
|
||||||
if "text_encoder_session" in components:
|
if "text_encoder_session" in components:
|
||||||
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
|
pipe.text_encoder = ORTModelTextEncoder(
|
||||||
|
components["text_encoder_session"], pipe
|
||||||
|
)
|
||||||
|
|
||||||
if "text_encoder_2_session" in components:
|
if "text_encoder_2_session" in components:
|
||||||
pipe.text_encoder_2 = ORTModelTextEncoder(
|
pipe.text_encoder_2 = ORTModelTextEncoder(
|
||||||
text_encoder_2_session, text_encoder_2
|
components["text_encoder_2_session"], pipe
|
||||||
)
|
)
|
||||||
|
|
||||||
if "unet_session" in components:
|
if "unet_session" in components:
|
||||||
# unload old UNet first
|
# unload old UNet
|
||||||
pipe.unet = None
|
pipe.unet = None
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
# load correct one
|
|
||||||
pipe.unet = ORTModelUnet(unet_session, unet_model)
|
# attach correct one
|
||||||
|
pipe.unet = ORTModelUnet(components["unet_session"], pipe)
|
||||||
|
|
||||||
if "vae_decoder_session" in components:
|
if "vae_decoder_session" in components:
|
||||||
pipe.vae_decoder = ORTModelVaeDecoder(
|
pipe.vae_decoder = ORTModelVaeDecoder(
|
||||||
|
@ -462,11 +259,9 @@ def load_pipeline(
|
||||||
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
|
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
|
||||||
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
|
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
|
||||||
|
|
||||||
if hasattr(pipe, "vae_decoder"):
|
for vae in VAE_COMPONENTS:
|
||||||
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
|
if hasattr(pipe, vae):
|
||||||
|
getattr(pipe, vae).set_tiled(tiled=params.tiled_vae)
|
||||||
if hasattr(pipe, "vae_encoder"):
|
|
||||||
pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
|
|
||||||
|
|
||||||
# update panorama params
|
# update panorama params
|
||||||
if params.is_panorama():
|
if params.is_panorama():
|
||||||
|
@ -474,16 +269,262 @@ def load_pipeline(
|
||||||
latent_stride = params.stride // 8
|
latent_stride = params.stride // 8
|
||||||
|
|
||||||
pipe.set_window_size(latent_window, latent_stride)
|
pipe.set_window_size(latent_window, latent_stride)
|
||||||
if hasattr(pipe, "vae_decoder"):
|
|
||||||
pipe.vae_decoder.set_window_size(latent_window, params.overlap)
|
for vae in VAE_COMPONENTS:
|
||||||
if hasattr(pipe, "vae_encoder"):
|
if hasattr(pipe, vae):
|
||||||
pipe.vae_encoder.set_window_size(latent_window, params.overlap)
|
getattr(pipe, vae).set_window_size(latent_window, params.overlap)
|
||||||
|
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet(server, device, params):
    """
    Load the ControlNet model selected by ``params.control``.

    The weights are read from ``<server.model_path>/control/<name>.onnx``,
    where ``<name>`` is ``params.control.name``.

    Returns a dict with a single ``"controlnet"`` entry, ready to be merged
    into the pipeline components.
    """
    weights_file = f"{params.control.name}.onnx"
    cnet_path = path.join(server.model_path, "control", weights_file)
    logger.debug("loading ControlNet weights from %s", cnet_path)

    session = OnnxRuntimeModel.load_model(
        cnet_path,
        provider=device.ort_provider(),
        sess_options=device.sess_options(),
    )
    return {"controlnet": OnnxRuntimeModel(session)}
|
||||||
|
|
||||||
|
|
||||||
|
def load_text_encoders(
    server, device, model: str, inversions, loras, torch_dtype, params
):
    """
    Load the tokenizer and text encoder(s) for a diffusion pipeline.

    The base text encoder is read from ``<model>/text_encoder`` and the
    tokenizer from the ``tokenizer`` subfolder of ``model``. Textual
    Inversions, when given, are blended into both. LoRAs, when given, are
    blended into the text encoder (and into ``text_encoder_2`` for XL models).

    Args:
        server: server context; provides ``model_path`` for the inversion and
            LoRA directories and is passed through to the blending helpers.
        device: provides the ONNX runtime provider and session options.
        model: path to the base model directory.
        inversions: iterable of ``(name, weight)`` pairs, or ``None``/empty.
        loras: iterable of ``(name, weight)`` pairs, or ``None``/empty.
        torch_dtype: passed to ``CLIPTokenizer.from_pretrained``.
        params: image parameters; only ``is_xl()`` is used here.

    Returns:
        A dict of components. It always contains ``"tokenizer"``, plus either
        ``"text_encoder"`` or, for XL models with LoRAs,
        ``"text_encoder_session"`` and ``"text_encoder_2_session"``.
    """
    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
    tokenizer = CLIPTokenizer.from_pretrained(
        model,
        subfolder="tokenizer",
        torch_dtype=torch_dtype,
    )

    # Textual Inversion blending
    if inversions is not None and len(inversions) > 0:
        logger.debug("blending Textual Inversions from %s", inversions)
        inversion_names, inversion_weights = zip(*inversions)

        inversion_models = [
            path.join(server.model_path, "inversion", name) for name in inversion_names
        ]

        text_encoder, tokenizer = blend_textual_inversions(
            server,
            text_encoder,
            tokenizer,
            list(
                zip(
                    inversion_models,
                    inversion_weights,
                    inversion_names,
                    [None] * len(inversion_models),
                )
            ),
        )

    # Register the tokenizer only after blending: blend_textual_inversions
    # returns the tokenizer to use, and `tokenizer` is rebound to it above.
    # Storing it earlier could leave the pre-blend tokenizer in the result.
    components = {"tokenizer": tokenizer}

    # should be pretty small and should not need external data
    if loras is None or len(loras) == 0:
        # TODO: handle XL encoders
        components["text_encoder"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
                text_encoder.SerializeToString(),
                provider=device.ort_provider("text-encoder"),
                sess_options=device.sess_options(),
            )
        )
        return components

    # blend and load text encoder
    lora_names, lora_weights = zip(*loras)
    lora_models = [path.join(server.model_path, "lora", name) for name in lora_names]
    logger.info("blending base model %s with LoRA models: %s", model, lora_models)

    text_encoder = blend_loras(
        server,
        text_encoder,
        list(zip(lora_models, lora_weights)),
        "text_encoder",
        1 if params.is_xl() else None,
        params.is_xl(),
    )
    (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
    text_encoder_names, text_encoder_values = zip(*text_encoder_data)
    # the buffered tensors are handed to the session as external initializers
    text_encoder_opts = device.sess_options(cache=False)
    text_encoder_opts.add_external_initializers(
        list(text_encoder_names), list(text_encoder_values)
    )

    if params.is_xl():
        # XL pipelines get a raw session; the caller wraps it later
        text_encoder_session = InferenceSession(
            text_encoder.SerializeToString(),
            providers=[device.ort_provider("text-encoder")],
            sess_options=text_encoder_opts,
        )
        text_encoder_session._model_path = path.join(model, "text_encoder")
        components["text_encoder_session"] = text_encoder_session
    else:
        components["text_encoder"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
                text_encoder.SerializeToString(),
                provider=device.ort_provider("text-encoder"),
                sess_options=text_encoder_opts,
            )
        )

    # The second XL encoder reads lora_models/lora_weights, so it can only be
    # built on the LoRA path.
    if params.is_xl():
        text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
        text_encoder_2 = blend_loras(
            server,
            text_encoder_2,
            list(zip(lora_models, lora_weights)),
            "text_encoder",
            2,
            params.is_xl(),
        )
        (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
            text_encoder_2
        )
        text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
        text_encoder_2_opts = device.sess_options(cache=False)
        text_encoder_2_opts.add_external_initializers(
            list(text_encoder_2_names), list(text_encoder_2_values)
        )

        text_encoder_2_session = InferenceSession(
            text_encoder_2.SerializeToString(),
            providers=[device.ort_provider("text-encoder")],
            sess_options=text_encoder_2_opts,
        )
        text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
        components["text_encoder_2_session"] = text_encoder_2_session

    return components
|
||||||
|
|
||||||
|
|
||||||
|
def load_unet(server, device, model, loras, unet_type, params):
    """
    Load the UNet for a diffusion pipeline, blending in LoRAs when given.

    The base model is read from ``<model>/<unet_type>``. With LoRAs, the
    blended model is loaded as a raw ``InferenceSession`` for XL models
    (``"unet_session"``) or wrapped in ``OnnxRuntimeModel`` otherwise
    (``"unet"``). Without LoRAs, non-XL models are loaded straight from disk;
    for XL models without LoRAs nothing is loaded and the dict stays empty.

    Returns a dict of components to merge into the pipeline.
    """
    loaded = {}
    unet_path = path.join(model, unet_type, ONNX_MODEL)

    # LoRA blending
    use_loras = loras is not None and len(loras) > 0
    if use_loras:
        names, weights = zip(*loras)
        lora_paths = [path.join(server.model_path, "lora", lora) for lora in names]
        logger.info("blending base model %s with LoRA models: %s", model, lora_paths)

        # blend and load unet
        blended = blend_loras(
            server,
            unet_path,
            list(zip(lora_paths, weights)),
            "unet",
            xl=params.is_xl(),
        )
        unet_model, external_data = buffer_external_data_tensors(blended)
        tensor_names, tensor_values = zip(*external_data)
        session_opts = device.sess_options(cache=False)
        session_opts.add_external_initializers(
            list(tensor_names), list(tensor_values)
        )

        if not params.is_xl():
            loaded["unet"] = OnnxRuntimeModel(
                OnnxRuntimeModel.load_model(
                    unet_model.SerializeToString(),
                    provider=device.ort_provider("unet"),
                    sess_options=session_opts,
                )
            )
        else:
            session = InferenceSession(
                unet_model.SerializeToString(),
                providers=[device.ort_provider("unet")],
                sess_options=session_opts,
            )
            session._model_path = path.join(model, "unet")
            loaded["unet_session"] = session

    # make sure a UNet has been loaded
    needs_plain_unet = not params.is_xl() and "unet" not in loaded
    if needs_plain_unet:
        logger.debug("loading UNet (%s) from %s", unet_type, unet_path)
        plain_session = OnnxRuntimeModel.load_model(
            unet_path,
            provider=device.ort_provider("unet"),
            sess_options=device.sess_options(),
        )
        loaded["unet"] = OnnxRuntimeModel(plain_session)

    return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae(server, device, model, params):
    """
    Load whichever VAE model(s) exist under the ``model`` directory.

    For non-XL models a combined ``vae`` is used when its file exists.
    Otherwise, if both ``vae_decoder`` and ``vae_encoder`` files exist, they
    are loaded as raw sessions for XL models (``"vae_decoder_session"`` and
    ``"vae_encoder_session"``) or wrapped in ``OnnxRuntimeModel`` for the rest
    (``"vae_decoder"`` and ``"vae_encoder"``).

    Returns a dict of components, which is empty when no VAE files are found.
    """

    def open_session(onnx_file):
        # every VAE variant shares the same provider and session options
        return OnnxRuntimeModel.load_model(
            onnx_file,
            provider=device.ort_provider("vae"),
            sess_options=device.sess_options(),
        )

    # one or more VAE models need to be loaded
    combined_file = path.join(model, "vae", ONNX_MODEL)
    decoder_file = path.join(model, "vae_decoder", ONNX_MODEL)
    encoder_file = path.join(model, "vae_encoder", ONNX_MODEL)

    if not params.is_xl() and path.exists(combined_file):
        logger.debug("loading VAE from %s", combined_file)
        return {"vae": OnnxRuntimeModel(open_session(combined_file))}

    if not (path.exists(decoder_file) and path.exists(encoder_file)):
        return {}

    if params.is_xl():
        logger.debug("loading VAE decoder from %s", decoder_file)
        decoder_session = open_session(decoder_file)
        decoder_session._model_path = decoder_file

        logger.debug("loading VAE encoder from %s", encoder_file)
        encoder_session = open_session(encoder_file)
        encoder_session._model_path = encoder_file

        return {
            "vae_decoder_session": decoder_session,
            "vae_encoder_session": encoder_session,
        }

    logger.debug("loading VAE decoder from %s", decoder_file)
    decoder = OnnxRuntimeModel(open_session(decoder_file))

    logger.debug("loading VAE encoder from %s", encoder_file)
    encoder = OnnxRuntimeModel(open_session(encoder_file))

    return {"vae_decoder": decoder, "vae_encoder": encoder}
|
||||||
|
|
||||||
|
|
||||||
def optimize_pipeline(
|
def optimize_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
|
|
|
@ -11,7 +11,6 @@ from onnx_web.diffusers.load import (
|
||||||
)
|
)
|
||||||
from onnx_web.diffusers.patches.unet import UNetWrapper
|
from onnx_web.diffusers.patches.unet import UNetWrapper
|
||||||
from onnx_web.diffusers.patches.vae import VAEWrapper
|
from onnx_web.diffusers.patches.vae import VAEWrapper
|
||||||
from onnx_web.diffusers.utils import expand_prompt
|
|
||||||
from onnx_web.params import ImageParams
|
from onnx_web.params import ImageParams
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from tests.mocks import MockPipeline
|
from tests.mocks import MockPipeline
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
"bokeh",
|
"bokeh",
|
||||||
"Civitai",
|
"Civitai",
|
||||||
"ckpt",
|
"ckpt",
|
||||||
|
"cnet",
|
||||||
"codebook",
|
"codebook",
|
||||||
"codeformer",
|
"codeformer",
|
||||||
"controlnet",
|
"controlnet",
|
||||||
|
@ -53,6 +54,8 @@
|
||||||
"KDPM",
|
"KDPM",
|
||||||
"Knollingcase",
|
"Knollingcase",
|
||||||
"Lanczos",
|
"Lanczos",
|
||||||
|
"loha",
|
||||||
|
"loras",
|
||||||
"Multistep",
|
"Multistep",
|
||||||
"ndarray",
|
"ndarray",
|
||||||
"numpy",
|
"numpy",
|
||||||
|
|
Loading…
Reference in New Issue