
feat(api): blend LoRAs and Textual Inversions from extras file

Sean Sube 2023-03-18 07:01:16 -05:00
parent 1d44f985a4
commit 84bd852837
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 71 additions and 11 deletions
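
For context, the conversion loop in this change reads optional `inversions` and `loras` lists from each diffusion entry in the extras file. A minimal sketch of the parsed shape that the code below expects, written as the Python dict `yaml.safe_load` would produce (all names, sources, and values here are placeholders, not part of this commit):

```python
# Illustrative only: one diffusion entry from an extras file after yaml.safe_load.
# The keys match the lookups in the conversion loop (name, source, format, token,
# weight); the actual names and URLs are placeholders.
model = {
    "name": "diffusion-example",
    "source": "runwayml/stable-diffusion-v1-5",  # placeholder base model
    "inversions": [
        {
            "name": "autumn",
            "source": "https://example.com/autumn.pt",  # placeholder URL
            "format": "embeddings",  # default when "format" is omitted
            "token": "autumn",       # defaults to the inversion name
            "weight": 1.0,           # defaults to 1.0
        },
    ],
    "loras": [
        {
            "name": "example-lora",
            "source": "https://example.com/example-lora.safetensors",  # placeholder URL
            "weight": 0.5,           # defaults to 1.0
        },
    ],
}
```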

View File

@@ -7,12 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
from jsonschema import ValidationError, validate
from onnx import load_model, save_model
from transformers import CLIPTokenizer
from yaml import safe_load
from .correction_gfpgan import convert_correction_gfpgan
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.lora import blend_loras
from .diffusion.original import convert_diffusion_original
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
from .diffusion.textual_inversion import blend_textual_inversions
from .upscale_resrgan import convert_upscale_resrgan
from .utils import (
ConversionContext,
@@ -229,22 +232,77 @@ def convert_models(ctx: ConversionContext, args, models: Models):
source,
)
# keep track of which models have been blended
blend_models = {}
for inversion in model.get("inversions", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))
if "tokenizer" not in blend_models:
blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(path.join(ctx.model_path, model), subfolder="tokenizer")
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_format = inversion.get("format", "embeddings")
inversion_source = fetch_model(
ctx, f"{name}-inversion-{inversion_name}", inversion_source
)
convert_diffusion_textual_inversion(
inversion_token = inversion.get("token", inversion_name)
inversion_weight = inversion.get("weight", 1.0)
blend_textual_inversions(
ctx,
inversion_name,
model["source"],
inversion_source,
inversion_format,
base_token=inversion.get("token"),
blend_models["text_encoder"],
blend_models["tokenizer"],
[inversion_source],
[inversion_format],
base_token=inversion_token,
inversion_weights=[inversion_weight],
)
for lora in model.get("loras", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))
if "unet" not in blend_models:
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "unet", "model.onnx"))
# load models if not loaded yet
lora_name = lora["name"]
lora_source = lora["source"]
lora_source = fetch_model(
ctx, f"{name}-lora-{lora_name}", lora_source
)
lora_weight = lora.get("weight", 1.0)
blend_loras(
ctx,
blend_models["text_encoder"],
[lora_name],
[lora_source],
"text_encoder",
lora_weights=[lora_weight],
)
if "tokenizer" in blend_models:
dest_path = path.join(ctx.model_path, model, "tokenizer")
logger.debug("saving blended tokenizer to %s", dest_path)
blend_models["tokenizer"].save_pretrained(dest_path)
for name in ["text_encoder", "unet"]:
if name in blend_models:
dest_path = path.join(ctx.model_path, model, name, "model.onnx")
logger.debug("saving blended %s model to %s", name, dest_path)
save_model(
blend_models[name],
dest_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
)
except Exception:
logger.exception(
"error converting diffusion model %s",

View File

@@ -130,6 +130,7 @@ def convert_diffusion_textual_inversion(
inversion: str,
format: str,
base_token: Optional[str] = None,
weight: Optional[float] = 1.0,
):
dest_path = path.join(context.model_path, f"inversion-{name}")
logger.info(
@@ -161,7 +162,7 @@ def convert_diffusion_textual_inversion(
tokenizer,
[inversion],
[format],
[1.0],
[weight],
base_token=(base_token or name),
)
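
The new `weight` argument is forwarded to `blend_textual_inversions` as a one-element list. As a rough sketch of how a per-inversion weight is commonly applied when merging embeddings into the text encoder's token table, the embedding rows are scaled before being appended for the new tokens; this is an assumption about the blending math, not code from this repository, and the helper name is made up:

```python
import numpy as np

def append_weighted_embeddings(
    token_table: np.ndarray,       # (vocab, dim) text encoder embedding table
    embeddings: list[np.ndarray],  # one (dim,) vector per inversion
    weights: list[float],          # per-inversion blend weights
) -> np.ndarray:
    """Append each inversion embedding, scaled by its weight, as a new row
    of the token embedding table (new tokens take the highest ids)."""
    rows = [emb * weight for emb, weight in zip(embeddings, weights)]
    return np.vstack([token_table, np.stack(rows)])

# toy example: one 768-dim embedding blended at half strength into a CLIP-sized table
table = np.zeros((49408, 768), dtype=np.float32)
autumn = np.random.randn(768).astype(np.float32)
table = append_weighted_embeddings(table, [autumn], [0.5])
print(table.shape)  # (49409, 768)
```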

View File

@@ -326,7 +326,8 @@ You can blend extra networks with the diffusion model using `<type:name:weight>`
#### LoRA tokens
You can blend one or more [LoRA embeddings](https://arxiv.org/abs/2106.09685) with the ONNX diffusion model using a `lora` token:
You can blend one or more [LoRA embeddings](https://arxiv.org/abs/2106.09685) with the ONNX diffusion model using a
`lora` token:
```none
<lora:name:0.5>
@@ -341,8 +342,8 @@ contain any special characters.
#### Textual Inversion tokens
You can blend one or more [Textual Inversions](https://textual-inversion.github.io/) with the ONNX diffusion model using the `inversion`
token:
You can blend one or more [Textual Inversions](https://textual-inversion.github.io/) with the ONNX diffusion model
using the `inversion` token:
```none
<inversion:autumn:1.0>