
lint(api): start breaking down model loading

Sean Sube 2023-09-23 20:11:05 -05:00
parent 38d3999088
commit 6b6f63564e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 280 additions and 237 deletions
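
The commit replaces the inline ControlNet, text-encoder, UNet, and VAE loading blocks in load_pipeline with four helpers (load_controlnet, load_text_encoders, load_unet, load_vae), each returning a dict of named components that the caller merges. A minimal runnable sketch of that merge pattern, with hypothetical stub loaders standing in for the real ones shown in the diff below:

# Each loader builds its own component dict; the pipeline only merges them.
# These stub names and return values are illustrative, not onnx-web code.
def load_unet_stub():
    return {"unet": "unet-session"}

def load_vae_stub():
    return {"vae_decoder": "decoder-session", "vae_encoder": "encoder-session"}

components = {}
components.update(load_unet_stub())
components.update(load_vae_stub())
print(sorted(components))  # ['unet', 'vae_decoder', 'vae_encoder']

This keeps each loader independently testable and lets load_pipeline stay a thin orchestrator.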

View File

@@ -106,6 +106,9 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
     return None


+VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"]
+
+
 def load_pipeline(
     server: ServerContext,
     params: ImageParams,
@@ -177,237 +180,28 @@ def load_pipeline(
         }

     # shared components
-    text_encoder = None
     unet_type = "unet"

     # ControlNet component
     if params.is_control() and params.control is not None:
-        cnet_path = path.join(
-            server.model_path, "control", f"{params.control.name}.onnx"
-        )
-        logger.debug("loading ControlNet weights from %s", cnet_path)
-        components["controlnet"] = OnnxRuntimeModel(
-            OnnxRuntimeModel.load_model(
-                cnet_path,
-                provider=device.ort_provider(),
-                sess_options=device.sess_options(),
-            )
-        )
+        logger.debug("loading ControlNet components")
+        control_components = load_controlnet(server, device, params)
+        components.update(control_components)
         unet_type = "cnet"

     # Textual Inversion blending
-    if inversions is not None and len(inversions) > 0:
-        logger.debug("blending Textual Inversions from %s", inversions)
-        inversion_names, inversion_weights = zip(*inversions)
-
-        inversion_models = [
-            path.join(server.model_path, "inversion", name)
-            for name in inversion_names
-        ]
-        text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
-        tokenizer = CLIPTokenizer.from_pretrained(
-            model,
-            subfolder="tokenizer",
-            torch_dtype=torch_dtype,
-        )
-        text_encoder, tokenizer = blend_textual_inversions(
-            server,
-            text_encoder,
-            tokenizer,
-            list(
-                zip(
-                    inversion_models,
-                    inversion_weights,
-                    inversion_names,
-                    [None] * len(inversion_models),
-                )
-            ),
-        )
-
-        components["tokenizer"] = tokenizer
-
-        # should be pretty small and should not need external data
-        if loras is None or len(loras) == 0:
-            # TODO: handle XL encoders
-            components["text_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    text_encoder.SerializeToString(),
-                    provider=device.ort_provider("text-encoder"),
-                    sess_options=device.sess_options(),
-                )
-            )
-
-    # LoRA blending
-    if loras is not None and len(loras) > 0:
-        lora_names, lora_weights = zip(*loras)
-        lora_models = [
-            path.join(server.model_path, "lora", name) for name in lora_names
-        ]
-        logger.info(
-            "blending base model %s with LoRA models: %s", model, lora_models
-        )
-
-        # blend and load text encoder
-        text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
-        text_encoder = blend_loras(
-            server,
-            text_encoder,
-            list(zip(lora_models, lora_weights)),
-            "text_encoder",
-            1 if params.is_xl() else None,
-            params.is_xl(),
-        )
-        (text_encoder, text_encoder_data) = buffer_external_data_tensors(
-            text_encoder
-        )
-        text_encoder_names, text_encoder_values = zip(*text_encoder_data)
-        text_encoder_opts = device.sess_options(cache=False)
-        text_encoder_opts.add_external_initializers(
-            list(text_encoder_names), list(text_encoder_values)
-        )
-
-        if params.is_xl():
-            text_encoder_session = InferenceSession(
-                text_encoder.SerializeToString(),
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_opts,
-            )
-            text_encoder_session._model_path = path.join(model, "text_encoder")
-            components["text_encoder_session"] = text_encoder_session
-        else:
-            components["text_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    text_encoder.SerializeToString(),
-                    provider=device.ort_provider("text-encoder"),
-                    sess_options=text_encoder_opts,
-                )
-            )
-
-        if params.is_xl():
-            text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
-            text_encoder_2 = blend_loras(
-                server,
-                text_encoder_2,
-                list(zip(lora_models, lora_weights)),
-                "text_encoder",
-                2,
-                params.is_xl(),
-            )
-            (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
-                text_encoder_2
-            )
-            text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
-            text_encoder_2_opts = device.sess_options(cache=False)
-            text_encoder_2_opts.add_external_initializers(
-                list(text_encoder_2_names), list(text_encoder_2_values)
-            )
-            text_encoder_2_session = InferenceSession(
-                text_encoder_2.SerializeToString(),
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_2_opts,
-            )
-            text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
-            components["text_encoder_2_session"] = text_encoder_2_session
-
-        # blend and load unet
-        unet = path.join(model, unet_type, ONNX_MODEL)
-        blended_unet = blend_loras(
-            server,
-            unet,
-            list(zip(lora_models, lora_weights)),
-            "unet",
-            xl=params.is_xl(),
-        )
-        (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
-        unet_names, unet_values = zip(*unet_data)
-        unet_opts = device.sess_options(cache=False)
-        unet_opts.add_external_initializers(list(unet_names), list(unet_values))
-
-        if params.is_xl():
-            unet_session = InferenceSession(
-                unet_model.SerializeToString(),
-                providers=[device.ort_provider("unet")],
-                sess_options=unet_opts,
-            )
-            unet_session._model_path = path.join(model, "unet")
-            components["unet_session"] = unet_session
-        else:
-            components["unet"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    unet_model.SerializeToString(),
-                    provider=device.ort_provider("unet"),
-                    sess_options=unet_opts,
-                )
-            )
-
-    # make sure a UNet has been loaded
-    if not params.is_xl() and "unet" not in components:
-        unet = path.join(model, unet_type, ONNX_MODEL)
-        logger.debug("loading UNet (%s) from %s", unet_type, unet)
-        components["unet"] = OnnxRuntimeModel(
-            OnnxRuntimeModel.load_model(
-                unet,
-                provider=device.ort_provider("unet"),
-                sess_options=device.sess_options(),
-            )
-        )
-
-    # one or more VAE models need to be loaded
-    vae = path.join(model, "vae", ONNX_MODEL)
-    vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
-    vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
-
-    if not params.is_xl() and path.exists(vae):
-        logger.debug("loading VAE from %s", vae)
-        components["vae"] = OnnxRuntimeModel(
-            OnnxRuntimeModel.load_model(
-                vae,
-                provider=device.ort_provider("vae"),
-                sess_options=device.sess_options(),
-            )
-        )
-    elif path.exists(vae_decoder) and path.exists(vae_encoder):
-        if params.is_xl():
-            logger.debug("loading VAE decoder from %s", vae_decoder)
-            components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
-                vae_decoder,
-                provider=device.ort_provider("vae"),
-                sess_options=device.sess_options(),
-            )
-            components[
-                "vae_decoder_session"
-            ]._model_path = vae_decoder  # not a real path on any system
-
-            logger.debug("loading VAE encoder from %s", vae_encoder)
-            components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
-                vae_encoder,
-                provider=device.ort_provider("vae"),
-                sess_options=device.sess_options(),
-            )
-            components[
-                "vae_encoder_session"
-            ]._model_path = vae_encoder  # not a real path on any system
-        else:
-            logger.debug("loading VAE decoder from %s", vae_decoder)
-            components["vae_decoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    vae_decoder,
-                    provider=device.ort_provider("vae"),
-                    sess_options=device.sess_options(),
-                )
-            )
-
-            logger.debug("loading VAE encoder from %s", vae_encoder)
-            components["vae_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    vae_encoder,
-                    provider=device.ort_provider("vae"),
-                    sess_options=device.sess_options(),
-                )
-            )
+    encoder_components = load_text_encoders(
+        server, device, model, inversions, loras, torch_dtype, params
+    )
+    components.update(encoder_components)
+
+    unet_components = load_unet(
+        server, device, model, loras, unet_type, params
+    )
+    components.update(unet_components)
+
+    vae_components = load_vae(server, device, model, params)
+    components.update(vae_components)

     # additional options for panorama pipeline
     if params.is_panorama():
@@ -427,19 +221,22 @@ def load_pipeline(
     # make sure XL models are actually being used
     if "text_encoder_session" in components:
-        pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
+        pipe.text_encoder = ORTModelTextEncoder(
+            components["text_encoder_session"], pipe
+        )

     if "text_encoder_2_session" in components:
         pipe.text_encoder_2 = ORTModelTextEncoder(
-            text_encoder_2_session, text_encoder_2
+            components["text_encoder_2_session"], pipe
         )

     if "unet_session" in components:
-        # unload old UNet first
+        # unload old UNet
         pipe.unet = None
         run_gc([device])

-        # load correct one
-        pipe.unet = ORTModelUnet(unet_session, unet_model)
+        # attach correct one
+        pipe.unet = ORTModelUnet(components["unet_session"], pipe)

     if "vae_decoder_session" in components:
         pipe.vae_decoder = ORTModelVaeDecoder(
@@ -462,11 +259,9 @@ def load_pipeline(
         server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
         server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])

-    if hasattr(pipe, "vae_decoder"):
-        pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
-
-    if hasattr(pipe, "vae_encoder"):
-        pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
+    for vae in VAE_COMPONENTS:
+        if hasattr(pipe, vae):
+            getattr(pipe, vae).set_tiled(tiled=params.tiled_vae)

     # update panorama params
     if params.is_panorama():
@@ -474,16 +269,262 @@ def load_pipeline(
         latent_stride = params.stride // 8
         pipe.set_window_size(latent_window, latent_stride)

-        if hasattr(pipe, "vae_decoder"):
-            pipe.vae_decoder.set_window_size(latent_window, params.overlap)
-        if hasattr(pipe, "vae_encoder"):
-            pipe.vae_encoder.set_window_size(latent_window, params.overlap)
+        for vae in VAE_COMPONENTS:
+            if hasattr(pipe, vae):
+                getattr(pipe, vae).set_window_size(latent_window, params.overlap)

     run_gc([device])

     return pipe
+
+
+def load_controlnet(server, device, params):
+    cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
+    logger.debug("loading ControlNet weights from %s", cnet_path)
+
+    components = {}
+    components["controlnet"] = OnnxRuntimeModel(
+        OnnxRuntimeModel.load_model(
+            cnet_path,
+            provider=device.ort_provider(),
+            sess_options=device.sess_options(),
+        )
+    )
+    return components
+
+
+def load_text_encoders(
+    server, device, model: str, inversions, loras, torch_dtype, params
+):
+    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
+    tokenizer = CLIPTokenizer.from_pretrained(
+        model,
+        subfolder="tokenizer",
+        torch_dtype=torch_dtype,
+    )
+
+    components = {}
+
+    if inversions is not None and len(inversions) > 0:
+        logger.debug("blending Textual Inversions from %s", inversions)
+        inversion_names, inversion_weights = zip(*inversions)
+        inversion_models = [
+            path.join(server.model_path, "inversion", name) for name in inversion_names
+        ]
+        text_encoder, tokenizer = blend_textual_inversions(
+            server,
+            text_encoder,
+            tokenizer,
+            list(
+                zip(
+                    inversion_models,
+                    inversion_weights,
+                    inversion_names,
+                    [None] * len(inversion_models),
+                )
+            ),
+        )
+
+    components["tokenizer"] = tokenizer
+
+    # should be pretty small and should not need external data
+    if loras is None or len(loras) == 0:
+        # TODO: handle XL encoders
+        components["text_encoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                text_encoder.SerializeToString(),
+                provider=device.ort_provider("text-encoder"),
+                sess_options=device.sess_options(),
+            )
+        )
+    else:
+        # blend and load text encoder
+        lora_names, lora_weights = zip(*loras)
+        lora_models = [
+            path.join(server.model_path, "lora", name) for name in lora_names
+        ]
+        logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+        text_encoder = blend_loras(
+            server,
+            text_encoder,
+            list(zip(lora_models, lora_weights)),
+            "text_encoder",
+            1 if params.is_xl() else None,
+            params.is_xl(),
+        )
+        (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
+        text_encoder_names, text_encoder_values = zip(*text_encoder_data)
+        text_encoder_opts = device.sess_options(cache=False)
+        text_encoder_opts.add_external_initializers(
+            list(text_encoder_names), list(text_encoder_values)
+        )
+
+        if params.is_xl():
+            text_encoder_session = InferenceSession(
+                text_encoder.SerializeToString(),
+                providers=[device.ort_provider("text-encoder")],
+                sess_options=text_encoder_opts,
+            )
+            text_encoder_session._model_path = path.join(model, "text_encoder")
+            components["text_encoder_session"] = text_encoder_session
+        else:
+            components["text_encoder"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    text_encoder.SerializeToString(),
+                    provider=device.ort_provider("text-encoder"),
+                    sess_options=text_encoder_opts,
+                )
+            )
+
+        if params.is_xl():
+            text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
+            text_encoder_2 = blend_loras(
+                server,
+                text_encoder_2,
+                list(zip(lora_models, lora_weights)),
+                "text_encoder",
+                2,
+                params.is_xl(),
+            )
+            (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
+                text_encoder_2
+            )
+            text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
+            text_encoder_2_opts = device.sess_options(cache=False)
+            text_encoder_2_opts.add_external_initializers(
+                list(text_encoder_2_names), list(text_encoder_2_values)
+            )
+            text_encoder_2_session = InferenceSession(
+                text_encoder_2.SerializeToString(),
+                providers=[device.ort_provider("text-encoder")],
+                sess_options=text_encoder_2_opts,
+            )
+            text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
+            components["text_encoder_2_session"] = text_encoder_2_session
+
+    return components
+
+
+def load_unet(server, device, model, loras, unet_type, params):
+    components = {}
+    unet = path.join(model, unet_type, ONNX_MODEL)
+
+    # LoRA blending
+    if loras is not None and len(loras) > 0:
+        lora_names, lora_weights = zip(*loras)
+        lora_models = [
+            path.join(server.model_path, "lora", name) for name in lora_names
+        ]
+        logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+
+        # blend and load unet
+        blended_unet = blend_loras(
+            server,
+            unet,
+            list(zip(lora_models, lora_weights)),
+            "unet",
+            xl=params.is_xl(),
+        )
+        (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
+        unet_names, unet_values = zip(*unet_data)
+        unet_opts = device.sess_options(cache=False)
+        unet_opts.add_external_initializers(list(unet_names), list(unet_values))
+
+        if params.is_xl():
+            unet_session = InferenceSession(
+                unet_model.SerializeToString(),
+                providers=[device.ort_provider("unet")],
+                sess_options=unet_opts,
+            )
+            unet_session._model_path = path.join(model, "unet")
+            components["unet_session"] = unet_session
+        else:
+            components["unet"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    unet_model.SerializeToString(),
+                    provider=device.ort_provider("unet"),
+                    sess_options=unet_opts,
+                )
+            )
+
+    # make sure a UNet has been loaded
+    if not params.is_xl() and "unet" not in components:
+        unet = path.join(model, unet_type, ONNX_MODEL)
+        logger.debug("loading UNet (%s) from %s", unet_type, unet)
+        components["unet"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                unet,
+                provider=device.ort_provider("unet"),
+                sess_options=device.sess_options(),
+            )
+        )
+
+    return components
+
+
+def load_vae(server, device, model, params):
+    # one or more VAE models need to be loaded
+    vae = path.join(model, "vae", ONNX_MODEL)
+    vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
+    vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
+
+    components = {}
+    if not params.is_xl() and path.exists(vae):
+        logger.debug("loading VAE from %s", vae)
+        components["vae"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                vae,
+                provider=device.ort_provider("vae"),
+                sess_options=device.sess_options(),
+            )
+        )
+    elif path.exists(vae_decoder) and path.exists(vae_encoder):
+        if params.is_xl():
+            logger.debug("loading VAE decoder from %s", vae_decoder)
+            components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
+                vae_decoder,
+                provider=device.ort_provider("vae"),
+                sess_options=device.sess_options(),
+            )
+            components["vae_decoder_session"]._model_path = vae_decoder
+
+            logger.debug("loading VAE encoder from %s", vae_encoder)
+            components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
+                vae_encoder,
+                provider=device.ort_provider("vae"),
+                sess_options=device.sess_options(),
+            )
+            components["vae_encoder_session"]._model_path = vae_encoder
+        else:
+            logger.debug("loading VAE decoder from %s", vae_decoder)
+            components["vae_decoder"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    vae_decoder,
+                    provider=device.ort_provider("vae"),
+                    sess_options=device.sess_options(),
+                )
+            )
+
+            logger.debug("loading VAE encoder from %s", vae_encoder)
+            components["vae_encoder"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    vae_encoder,
+                    provider=device.ort_provider("vae"),
+                    sess_options=device.sess_options(),
+                )
+            )
+
+    return components


 def optimize_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
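
The new VAE_COMPONENTS constant also collapses the duplicated vae_decoder/vae_encoder branches into a single hasattr/getattr loop. A runnable sketch of that fan-out, with placeholder classes standing in for the real pipeline and ORT VAE wrappers:

# Fan out over optional components by attribute name, as in the new loop;
# FakeVae and FakePipe are placeholders, not onnx-web classes.
VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"]

class FakeVae:
    def set_tiled(self, tiled):
        print("tiled:", tiled)

class FakePipe:
    def __init__(self):
        # most pipelines expose only a subset of the three attributes
        self.vae_decoder = FakeVae()
        self.vae_encoder = FakeVae()

pipe = FakePipe()
for vae in VAE_COMPONENTS:
    if hasattr(pipe, vae):
        getattr(pipe, vae).set_tiled(tiled=True)  # the missing "vae" is skipped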

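The text-encoder and UNet paths both serialize the LoRA-blended graph with SerializeToString and re-attach the large tensors through add_external_initializers, which keeps the serialized protobuf within its 2 GB limit. A sketch of that onnxruntime pattern under the same assumptions the diff makes (external_data is a list of name/value pairs, as produced by buffer_external_data_tensors):

# Strip large tensors from the ONNX graph, then hand them back to the
# runtime as external initializers instead of inlining them in the model.
from onnxruntime import InferenceSession, SessionOptions

def session_with_external_data(model_bytes, external_data):
    names, values = zip(*external_data)
    opts = SessionOptions()
    opts.add_external_initializers(list(names), list(values))
    return InferenceSession(model_bytes, sess_options=opts)
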
View File

@@ -11,7 +11,6 @@ from onnx_web.diffusers.load import (
 )
 from onnx_web.diffusers.patches.unet import UNetWrapper
 from onnx_web.diffusers.patches.vae import VAEWrapper
-from onnx_web.diffusers.utils import expand_prompt
 from onnx_web.params import ImageParams
 from onnx_web.server.context import ServerContext
 from tests.mocks import MockPipeline

View File

@@ -26,6 +26,7 @@
     "bokeh",
     "Civitai",
     "ckpt",
+    "cnet",
     "codebook",
     "codeformer",
     "controlnet",
@@ -53,6 +54,8 @@
     "KDPM",
     "Knollingcase",
     "Lanczos",
+    "loha",
+    "loras",
     "Multistep",
     "ndarray",
     "numpy",