diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index c38fe487..4e4c8c90 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -106,6 +106,9 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
     return None
 
 
+VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"]
+
+
 def load_pipeline(
     server: ServerContext,
     params: ImageParams,
@@ -177,237 +180,28 @@ def load_pipeline(
     }
 
     # shared components
-    text_encoder = None
     unet_type = "unet"
 
     # ControlNet component
     if params.is_control() and params.control is not None:
-        cnet_path = path.join(
-            server.model_path, "control", f"{params.control.name}.onnx"
-        )
-        logger.debug("loading ControlNet weights from %s", cnet_path)
-        components["controlnet"] = OnnxRuntimeModel(
-            OnnxRuntimeModel.load_model(
-                cnet_path,
-                provider=device.ort_provider(),
-                sess_options=device.sess_options(),
-            )
-        )
-
+        logger.debug("loading ControlNet components")
+        control_components = load_controlnet(server, device, params)
+        components.update(control_components)
         unet_type = "cnet"
 
     # Textual Inversion blending
-    if inversions is not None and len(inversions) > 0:
-        logger.debug("blending Textual Inversions from %s", inversions)
-        inversion_names, inversion_weights = zip(*inversions)
+    encoder_components = load_text_encoders(
+        server, device, model, inversions, loras, torch_dtype, params
+    )
+    components.update(encoder_components)
 
-        inversion_models = [
-            path.join(server.model_path, "inversion", name)
-            for name in inversion_names
-        ]
-        text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
-        tokenizer = CLIPTokenizer.from_pretrained(
-            model,
-            subfolder="tokenizer",
-            torch_dtype=torch_dtype,
-        )
-        text_encoder, tokenizer = blend_textual_inversions(
-            server,
-            text_encoder,
-            tokenizer,
-            list(
-                zip(
-                    inversion_models,
-                    inversion_weights,
-                    inversion_names,
-                    [None] * len(inversion_models),
-                )
-            ),
-        )
+    unet_components = load_unet(
+        server, device, model, loras, unet_type, params
+    )
+    components.update(unet_components)
 
-        components["tokenizer"] = tokenizer
-
-        # should be pretty small and should not need external data
-        if loras is None or len(loras) == 0:
-            # TODO: handle XL encoders
-            components["text_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    text_encoder.SerializeToString(),
-                    provider=device.ort_provider("text-encoder"),
-                    sess_options=device.sess_options(),
-                )
-            )
-
-    # LoRA blending
-    if loras is not None and len(loras) > 0:
-        lora_names, lora_weights = zip(*loras)
-        lora_models = [
-            path.join(server.model_path, "lora", name) for name in lora_names
-        ]
-        logger.info(
-            "blending base model %s with LoRA models: %s", model, lora_models
-        )
-
-        # blend and load text encoder
-        text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
-        text_encoder = blend_loras(
-            server,
-            text_encoder,
-            list(zip(lora_models, lora_weights)),
-            "text_encoder",
-            1 if params.is_xl() else None,
-            params.is_xl(),
-        )
-        (text_encoder, text_encoder_data) = buffer_external_data_tensors(
-            text_encoder
-        )
-        text_encoder_names, text_encoder_values = zip(*text_encoder_data)
-        text_encoder_opts = device.sess_options(cache=False)
-        text_encoder_opts.add_external_initializers(
-            list(text_encoder_names), list(text_encoder_values)
-        )
-
-        if params.is_xl():
-            text_encoder_session = InferenceSession(
-                text_encoder.SerializeToString(),
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_opts,
-            )
path.join(model, "text_encoder") - components["text_encoder_session"] = text_encoder_session - else: - components["text_encoder"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - text_encoder.SerializeToString(), - provider=device.ort_provider("text-encoder"), - sess_options=text_encoder_opts, - ) - ) - - if params.is_xl(): - text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL) - text_encoder_2 = blend_loras( - server, - text_encoder_2, - list(zip(lora_models, lora_weights)), - "text_encoder", - 2, - params.is_xl(), - ) - (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors( - text_encoder_2 - ) - text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data) - text_encoder_2_opts = device.sess_options(cache=False) - text_encoder_2_opts.add_external_initializers( - list(text_encoder_2_names), list(text_encoder_2_values) - ) - - text_encoder_2_session = InferenceSession( - text_encoder_2.SerializeToString(), - providers=[device.ort_provider("text-encoder")], - sess_options=text_encoder_2_opts, - ) - text_encoder_2_session._model_path = path.join(model, "text_encoder_2") - components["text_encoder_2_session"] = text_encoder_2_session - - # blend and load unet - unet = path.join(model, unet_type, ONNX_MODEL) - blended_unet = blend_loras( - server, - unet, - list(zip(lora_models, lora_weights)), - "unet", - xl=params.is_xl(), - ) - (unet_model, unet_data) = buffer_external_data_tensors(blended_unet) - unet_names, unet_values = zip(*unet_data) - unet_opts = device.sess_options(cache=False) - unet_opts.add_external_initializers(list(unet_names), list(unet_values)) - - if params.is_xl(): - unet_session = InferenceSession( - unet_model.SerializeToString(), - providers=[device.ort_provider("unet")], - sess_options=unet_opts, - ) - unet_session._model_path = path.join(model, "unet") - components["unet_session"] = unet_session - else: - components["unet"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - unet_model.SerializeToString(), - provider=device.ort_provider("unet"), - sess_options=unet_opts, - ) - ) - - # make sure a UNet has been loaded - if not params.is_xl() and "unet" not in components: - unet = path.join(model, unet_type, ONNX_MODEL) - logger.debug("loading UNet (%s) from %s", unet_type, unet) - components["unet"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - unet, - provider=device.ort_provider("unet"), - sess_options=device.sess_options(), - ) - ) - - # one or more VAE models need to be loaded - vae = path.join(model, "vae", ONNX_MODEL) - vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL) - vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL) - - if not params.is_xl() and path.exists(vae): - logger.debug("loading VAE from %s", vae) - components["vae"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - vae, - provider=device.ort_provider("vae"), - sess_options=device.sess_options(), - ) - ) - elif path.exists(vae_decoder) and path.exists(vae_encoder): - if params.is_xl(): - logger.debug("loading VAE decoder from %s", vae_decoder) - components["vae_decoder_session"] = OnnxRuntimeModel.load_model( - vae_decoder, - provider=device.ort_provider("vae"), - sess_options=device.sess_options(), - ) - components[ - "vae_decoder_session" - ]._model_path = vae_decoder # "#\\not a real path on any system" - - logger.debug("loading VAE encoder from %s", vae_encoder) - components["vae_encoder_session"] = OnnxRuntimeModel.load_model( - vae_encoder, - provider=device.ort_provider("vae"), - sess_options=device.sess_options(), - ) - 
-            components[
-                "vae_encoder_session"
-            ]._model_path = vae_encoder  # "#\\not a real path on any system"
-
-        else:
-            logger.debug("loading VAE decoder from %s", vae_decoder)
-            components["vae_decoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    vae_decoder,
-                    provider=device.ort_provider("vae"),
-                    sess_options=device.sess_options(),
-                )
-            )
-
-            logger.debug("loading VAE encoder from %s", vae_encoder)
-            components["vae_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    vae_encoder,
-                    provider=device.ort_provider("vae"),
-                    sess_options=device.sess_options(),
-                )
-            )
+    vae_components = load_vae(server, device, model, params)
+    components.update(vae_components)
 
     # additional options for panorama pipeline
     if params.is_panorama():
@@ -427,19 +221,22 @@
 
     # make sure XL models are actually being used
     if "text_encoder_session" in components:
-        pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
+        pipe.text_encoder = ORTModelTextEncoder(
+            components["text_encoder_session"], pipe
+        )
 
     if "text_encoder_2_session" in components:
         pipe.text_encoder_2 = ORTModelTextEncoder(
-            text_encoder_2_session, text_encoder_2
+            components["text_encoder_2_session"], pipe
         )
 
     if "unet_session" in components:
-        # unload old UNet first
+        # unload old UNet
         pipe.unet = None
         run_gc([device])
-        # load correct one
-        pipe.unet = ORTModelUnet(unet_session, unet_model)
+
+        # attach correct one
+        pipe.unet = ORTModelUnet(components["unet_session"], pipe)
 
     if "vae_decoder_session" in components:
         pipe.vae_decoder = ORTModelVaeDecoder(
@@ -462,11 +259,9 @@
         server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
         server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
 
-    if hasattr(pipe, "vae_decoder"):
-        pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
-
-    if hasattr(pipe, "vae_encoder"):
-        pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
+    for vae in VAE_COMPONENTS:
+        if hasattr(pipe, vae):
+            getattr(pipe, vae).set_tiled(tiled=params.tiled_vae)
 
     # update panorama params
     if params.is_panorama():
@@ -474,16 +269,262 @@
         latent_stride = params.stride // 8
 
         pipe.set_window_size(latent_window, latent_stride)
-        if hasattr(pipe, "vae_decoder"):
-            pipe.vae_decoder.set_window_size(latent_window, params.overlap)
-        if hasattr(pipe, "vae_encoder"):
-            pipe.vae_encoder.set_window_size(latent_window, params.overlap)
+
+        for vae in VAE_COMPONENTS:
+            if hasattr(pipe, vae):
+                getattr(pipe, vae).set_window_size(latent_window, params.overlap)
 
     run_gc([device])
 
     return pipe
 
 
+def load_controlnet(server, device, params):
+    cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
+    logger.debug("loading ControlNet weights from %s", cnet_path)
+    components = {}
+    components["controlnet"] = OnnxRuntimeModel(
+        OnnxRuntimeModel.load_model(
+            cnet_path,
+            provider=device.ort_provider(),
+            sess_options=device.sess_options(),
+        )
+    )
+    return components
+
+
+def load_text_encoders(
+    server, device, model: str, inversions, loras, torch_dtype, params
+):
+    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
+    tokenizer = CLIPTokenizer.from_pretrained(
+        model,
+        subfolder="tokenizer",
+        torch_dtype=torch_dtype,
+    )
+
+    components = {}
+    components["tokenizer"] = tokenizer
+
+    if inversions is not None and len(inversions) > 0:
+        logger.debug("blending Textual Inversions from %s", inversions)
+        inversion_names, inversion_weights = zip(*inversions)
+
+        inversion_models = [
"inversion", name) for name in inversion_names + ] + + text_encoder, tokenizer = blend_textual_inversions( + server, + text_encoder, + tokenizer, + list( + zip( + inversion_models, + inversion_weights, + inversion_names, + [None] * len(inversion_models), + ) + ), + ) + + # should be pretty small and should not need external data + if loras is None or len(loras) == 0: + # TODO: handle XL encoders + components["text_encoder"] = OnnxRuntimeModel( + OnnxRuntimeModel.load_model( + text_encoder.SerializeToString(), + provider=device.ort_provider("text-encoder"), + sess_options=device.sess_options(), + ) + ) + else: + # blend and load text encoder + lora_names, lora_weights = zip(*loras) + lora_models = [ + path.join(server.model_path, "lora", name) for name in lora_names + ] + logger.info("blending base model %s with LoRA models: %s", model, lora_models) + + text_encoder = blend_loras( + server, + text_encoder, + list(zip(lora_models, lora_weights)), + "text_encoder", + 1 if params.is_xl() else None, + params.is_xl(), + ) + (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder) + text_encoder_names, text_encoder_values = zip(*text_encoder_data) + text_encoder_opts = device.sess_options(cache=False) + text_encoder_opts.add_external_initializers( + list(text_encoder_names), list(text_encoder_values) + ) + + if params.is_xl(): + text_encoder_session = InferenceSession( + text_encoder.SerializeToString(), + providers=[device.ort_provider("text-encoder")], + sess_options=text_encoder_opts, + ) + text_encoder_session._model_path = path.join(model, "text_encoder") + components["text_encoder_session"] = text_encoder_session + else: + components["text_encoder"] = OnnxRuntimeModel( + OnnxRuntimeModel.load_model( + text_encoder.SerializeToString(), + provider=device.ort_provider("text-encoder"), + sess_options=text_encoder_opts, + ) + ) + + if params.is_xl(): + text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL) + text_encoder_2 = blend_loras( + server, + text_encoder_2, + list(zip(lora_models, lora_weights)), + "text_encoder", + 2, + params.is_xl(), + ) + (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors( + text_encoder_2 + ) + text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data) + text_encoder_2_opts = device.sess_options(cache=False) + text_encoder_2_opts.add_external_initializers( + list(text_encoder_2_names), list(text_encoder_2_values) + ) + + text_encoder_2_session = InferenceSession( + text_encoder_2.SerializeToString(), + providers=[device.ort_provider("text-encoder")], + sess_options=text_encoder_2_opts, + ) + text_encoder_2_session._model_path = path.join(model, "text_encoder_2") + components["text_encoder_2_session"] = text_encoder_2_session + + return components + + +def load_unet(server, device, model, loras, unet_type, params): + components = {} + unet = path.join(model, unet_type, ONNX_MODEL) + + # LoRA blending + if loras is not None and len(loras) > 0: + lora_names, lora_weights = zip(*loras) + lora_models = [ + path.join(server.model_path, "lora", name) for name in lora_names + ] + logger.info("blending base model %s with LoRA models: %s", model, lora_models) + + # blend and load unet + blended_unet = blend_loras( + server, + unet, + list(zip(lora_models, lora_weights)), + "unet", + xl=params.is_xl(), + ) + (unet_model, unet_data) = buffer_external_data_tensors(blended_unet) + unet_names, unet_values = zip(*unet_data) + unet_opts = device.sess_options(cache=False) + 
+        unet_opts.add_external_initializers(list(unet_names), list(unet_values))
+
+        if params.is_xl():
+            unet_session = InferenceSession(
+                unet_model.SerializeToString(),
+                providers=[device.ort_provider("unet")],
+                sess_options=unet_opts,
+            )
+            unet_session._model_path = path.join(model, "unet")
+            components["unet_session"] = unet_session
+        else:
+            components["unet"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    unet_model.SerializeToString(),
+                    provider=device.ort_provider("unet"),
+                    sess_options=unet_opts,
+                )
+            )
+
+    # make sure a UNet has been loaded
+    if not params.is_xl() and "unet" not in components:
+        unet = path.join(model, unet_type, ONNX_MODEL)
+        logger.debug("loading UNet (%s) from %s", unet_type, unet)
+        components["unet"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                unet,
+                provider=device.ort_provider("unet"),
+                sess_options=device.sess_options(),
+            )
+        )
+
+    return components
+
+
+def load_vae(server, device, model, params):
+    # one or more VAE models need to be loaded
+    vae = path.join(model, "vae", ONNX_MODEL)
+    vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
+    vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
+
+    components = {}
+    if not params.is_xl() and path.exists(vae):
+        logger.debug("loading VAE from %s", vae)
+        components["vae"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                vae,
+                provider=device.ort_provider("vae"),
+                sess_options=device.sess_options(),
+            )
+        )
+    elif path.exists(vae_decoder) and path.exists(vae_encoder):
+        if params.is_xl():
+            logger.debug("loading VAE decoder from %s", vae_decoder)
+            components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
+                vae_decoder,
+                provider=device.ort_provider("vae"),
+                sess_options=device.sess_options(),
+            )
+            components[
+                "vae_decoder_session"
+            ]._model_path = vae_decoder
+
+            logger.debug("loading VAE encoder from %s", vae_encoder)
+            components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
+                vae_encoder,
+                provider=device.ort_provider("vae"),
+                sess_options=device.sess_options(),
+            )
+            components[
+                "vae_encoder_session"
+            ]._model_path = vae_encoder
+
+        else:
+            logger.debug("loading VAE decoder from %s", vae_decoder)
+            components["vae_decoder"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    vae_decoder,
+                    provider=device.ort_provider("vae"),
+                    sess_options=device.sess_options(),
+                )
+            )
+
+            logger.debug("loading VAE encoder from %s", vae_encoder)
+            components["vae_encoder"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    vae_encoder,
+                    provider=device.ort_provider("vae"),
+                    sess_options=device.sess_options(),
+                )
+            )
+
+    return components
+
+
 def optimize_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
diff --git a/api/tests/test_diffusers/test_load.py b/api/tests/test_diffusers/test_load.py
index a7f4c97c..beaab12c 100644
--- a/api/tests/test_diffusers/test_load.py
+++ b/api/tests/test_diffusers/test_load.py
@@ -11,7 +11,6 @@ from onnx_web.diffusers.load import (
 )
 from onnx_web.diffusers.patches.unet import UNetWrapper
 from onnx_web.diffusers.patches.vae import VAEWrapper
-from onnx_web.diffusers.utils import expand_prompt
 from onnx_web.params import ImageParams
 from onnx_web.server.context import ServerContext
 from tests.mocks import MockPipeline
diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace
index f859a141..08f791d9 100644
--- a/onnx-web.code-workspace
+++ b/onnx-web.code-workspace
@@ -26,6 +26,7 @@
     "bokeh",
     "Civitai",
     "ckpt",
+    "cnet",
     "codebook",
     "codeformer",
     "controlnet",
@@ -53,6 +54,8 @@
     "KDPM",
     "Knollingcase",
"Lanczos", + "loha", + "loras", "Multistep", "ndarray", "numpy",