diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 697a3b29..ad55051c 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -7,12 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse from jsonschema import ValidationError, validate +from onnx import load_model, save_model +from transformers import CLIPTokenizer from yaml import safe_load from .correction_gfpgan import convert_correction_gfpgan from .diffusion.diffusers import convert_diffusion_diffusers +from .diffusion.lora import blend_loras from .diffusion.original import convert_diffusion_original -from .diffusion.textual_inversion import convert_diffusion_textual_inversion +from .diffusion.textual_inversion import blend_textual_inversions from .upscale_resrgan import convert_upscale_resrgan from .utils import ( ConversionContext, @@ -229,22 +232,77 @@ def convert_models(ctx: ConversionContext, args, models: Models): source, ) + # keep track of which models have been blended + blend_models = {} + for inversion in model.get("inversions", []): + if "text_encoder" not in blend_models: + blend_models["text_encoder"] = load_model(path.join(ctx.model_path, name, "text_encoder", "model.onnx")) + + if "tokenizer" not in blend_models: + blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(path.join(ctx.model_path, name), subfolder="tokenizer") + inversion_name = inversion["name"] inversion_source = inversion["source"] inversion_format = inversion.get("format", "embeddings") inversion_source = fetch_model( ctx, f"{name}-inversion-{inversion_name}", inversion_source ) - convert_diffusion_textual_inversion( + inversion_token = inversion.get("token", inversion_name) + inversion_weight = inversion.get("weight", 1.0) + + blend_textual_inversions( ctx, - inversion_name, - model["source"], - inversion_source, - inversion_format, - base_token=inversion.get("token"), + blend_models["text_encoder"], + 
blend_models["tokenizer"], + [inversion_source], + [inversion_format], + base_token=inversion_token, + inversion_weights=[inversion_weight], ) + for lora in model.get("loras", []): + if "text_encoder" not in blend_models: + blend_models["text_encoder"] = load_model(path.join(ctx.model_path, name, "text_encoder", "model.onnx")) + + if "unet" not in blend_models: + blend_models["unet"] = load_model(path.join(ctx.model_path, name, "unet", "model.onnx")) + + # load models if not loaded yet + lora_name = lora["name"] + lora_source = lora["source"] + lora_source = fetch_model( + ctx, f"{name}-lora-{lora_name}", lora_source + ) + lora_weight = lora.get("weight", 1.0) + + blend_loras( + ctx, + blend_models["text_encoder"], + [lora_name], + [lora_source], + "text_encoder", + lora_weights=[lora_weight], + ) + + if "tokenizer" in blend_models: + dest_path = path.join(ctx.model_path, name, "tokenizer") + logger.debug("saving blended tokenizer to %s", dest_path) + blend_models["tokenizer"].save_pretrained(dest_path) + + for model_part in ["text_encoder", "unet"]: + if model_part in blend_models: + dest_path = path.join(ctx.model_path, name, model_part, "model.onnx") + logger.debug("saving blended %s model to %s", model_part, dest_path) + save_model( + blend_models[model_part], + dest_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + ) + + except Exception: logger.exception( "error converting diffusion model %s", diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index dd93d018..5aa4b715 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -130,6 +130,7 @@ def convert_diffusion_textual_inversion( inversion: str, format: str, base_token: Optional[str] = None, + weight: Optional[float] = 1.0, ): dest_path = path.join(context.model_path, f"inversion-{name}") logger.info( @@ -161,7 +162,7 @@ def 
convert_diffusion_textual_inversion( tokenizer, [inversion], [format], - [1.0], + [weight], base_token=(base_token or name), ) diff --git a/docs/user-guide.md b/docs/user-guide.md index b9b05d8e..1896052b 100644 --- a/docs/user-guide.md +++ b/docs/user-guide.md @@ -326,7 +326,8 @@ You can blend extra networks with the diffusion model using `` #### LoRA tokens -You can blend one or more [LoRA embeddings](https://arxiv.org/abs/2106.09685) with the ONNX diffusion model using a `lora` token: +You can blend one or more [LoRA embeddings](https://arxiv.org/abs/2106.09685) with the ONNX diffusion model using a +`lora` token: ```none @@ -341,8 +342,8 @@ contain any special characters. #### Textual Inversion tokens -You can blend one or more [Textual Inversions](https://textual-inversion.github.io/) with the ONNX diffusion model using the `inversion` -token: +You can blend one or more [Textual Inversions](https://textual-inversion.github.io/) with the ONNX diffusion model +using the `inversion` token: ```none