feat(api): blend LoRAs and Textual Inversions from extras file
This commit is contained in:
parent
1d44f985a4
commit
84bd852837
|
@ -7,12 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from jsonschema import ValidationError, validate
|
||||
from onnx import load_model, save_model
|
||||
from transformers import CLIPTokenizer
|
||||
from yaml import safe_load
|
||||
|
||||
from .correction_gfpgan import convert_correction_gfpgan
|
||||
from .diffusion.diffusers import convert_diffusion_diffusers
|
||||
from .diffusion.lora import blend_loras
|
||||
from .diffusion.original import convert_diffusion_original
|
||||
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
|
||||
from .diffusion.textual_inversion import blend_textual_inversions
|
||||
from .upscale_resrgan import convert_upscale_resrgan
|
||||
from .utils import (
|
||||
ConversionContext,
|
||||
|
@ -229,22 +232,77 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
source,
|
||||
)
|
||||
|
||||
# keep track of which models have been blended
|
||||
blend_models = {}
|
||||
|
||||
for inversion in model.get("inversions", []):
|
||||
if "text_encoder" not in blend_models:
|
||||
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))
|
||||
|
||||
if "tokenizer" not in blend_models:
|
||||
blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(path.join(ctx.model_path, model), subfolder="tokenizer")
|
||||
|
||||
inversion_name = inversion["name"]
|
||||
inversion_source = inversion["source"]
|
||||
inversion_format = inversion.get("format", "embeddings")
|
||||
inversion_source = fetch_model(
|
||||
ctx, f"{name}-inversion-{inversion_name}", inversion_source
|
||||
)
|
||||
convert_diffusion_textual_inversion(
|
||||
inversion_token = inversion.get("token", inversion_name)
|
||||
inversion_weight = inversion.get("weight", 1.0)
|
||||
|
||||
blend_textual_inversions(
|
||||
ctx,
|
||||
inversion_name,
|
||||
model["source"],
|
||||
inversion_source,
|
||||
inversion_format,
|
||||
base_token=inversion.get("token"),
|
||||
blend_models["text_encoder"],
|
||||
blend_models["tokenizer"],
|
||||
[inversion_source],
|
||||
[inversion_format],
|
||||
base_token=inversion_token,
|
||||
inversion_weights=[inversion_weight],
|
||||
)
|
||||
|
||||
for lora in model.get("loras", []):
|
||||
if "text_encoder" not in blend_models:
|
||||
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))
|
||||
|
||||
if "unet" not in blend_models:
|
||||
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "unet", "model.onnx"))
|
||||
|
||||
# base models are loaded lazily above; now resolve this LoRA's name, source, and weight
|
||||
lora_name = lora["name"]
|
||||
lora_source = lora["source"]
|
||||
lora_source = fetch_model(
|
||||
ctx, f"{name}-lora-{lora_name}", lora_source
|
||||
)
|
||||
lora_weight = lora.get("weight", 1.0)
|
||||
|
||||
blend_loras(
|
||||
ctx,
|
||||
blend_models["text_encoder"],
|
||||
[lora_name],
|
||||
[lora_source],
|
||||
"text_encoder",
|
||||
lora_weights=[lora_weight],
|
||||
)
|
||||
|
||||
if "tokenizer" in blend_models:
|
||||
dest_path = path.join(ctx.model_path, model, "tokenizer")
|
||||
logger.debug("saving blended tokenizer to %s", dest_path)
|
||||
blend_models["tokenizer"].save_pretrained(dest_path)
|
||||
|
||||
for name in ["text_encoder", "unet"]:
|
||||
if name in blend_models:
|
||||
dest_path = path.join(ctx.model_path, model, name, "model.onnx")
|
||||
logger.debug("saving blended %s model to %s", name, dest_path)
|
||||
save_model(
|
||||
blend_models[name],
|
||||
dest_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location="weights.pb",
|
||||
)
|
||||
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error converting diffusion model %s",
|
||||
|
|
|
@ -130,6 +130,7 @@ def convert_diffusion_textual_inversion(
|
|||
inversion: str,
|
||||
format: str,
|
||||
base_token: Optional[str] = None,
|
||||
weight: Optional[float] = 1.0,
|
||||
):
|
||||
dest_path = path.join(context.model_path, f"inversion-{name}")
|
||||
logger.info(
|
||||
|
@ -161,7 +162,7 @@ def convert_diffusion_textual_inversion(
|
|||
tokenizer,
|
||||
[inversion],
|
||||
[format],
|
||||
[1.0],
|
||||
[weight],
|
||||
base_token=(base_token or name),
|
||||
)
|
||||
|
||||
|
|
|
@ -326,7 +326,8 @@ You can blend extra networks with the diffusion model using `<type:name:weight>`
|
|||
|
||||
#### LoRA tokens
|
||||
|
||||
You can blend one or more [LoRA embeddings](https://arxiv.org/abs/2106.09685) with the ONNX diffusion model using a `lora` token:
|
||||
You can blend one or more [LoRA embeddings](https://arxiv.org/abs/2106.09685) with the ONNX diffusion model using a
|
||||
`lora` token:
|
||||
|
||||
```none
|
||||
<lora:name:0.5>
|
||||
|
@ -341,8 +342,8 @@ contain any special characters.
|
|||
|
||||
#### Textual Inversion tokens
|
||||
|
||||
You can blend one or more [Textual Inversions](https://textual-inversion.github.io/) with the ONNX diffusion model using the `inversion`
|
||||
token:
|
||||
You can blend one or more [Textual Inversions](https://textual-inversion.github.io/) with the ONNX diffusion model
|
||||
using the `inversion` token:
|
||||
|
||||
```none
|
||||
<inversion:autumn:1.0>
|
||||
|
|
Loading…
Reference in New Issue