1
0
Fork 0

load inversions from extras file

This commit is contained in:
Sean Sube 2023-02-21 21:40:57 -06:00
parent 9dedfc7b28
commit 6b4ced2608
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 31 additions and 17 deletions

View File

@ -217,9 +217,6 @@ def convert_models(ctx: ConversionContext, args, models: Models):
ctx, name, model["source"], model_format=model_format
)
if "inversion" in model:
convert_diffusion_textual_inversion(ctx, source, model["inversion"])
if model_format in model_formats_original:
convert_diffusion_original(
ctx,
@ -232,6 +229,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
model,
source,
)
for inversion in model.get("inversions", []):
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_source = fetch_model(ctx, f"{name}-inversion-{inversion_name}", inversion_source)
convert_diffusion_textual_inversion(ctx, inversion_name, source, inversion_source)
except Exception as e:
logger.error("error converting diffusion model %s: %s", name, e)

View File

@ -39,7 +39,7 @@ def blend_loras(base: ModelProto, weights: List[ModelProto], alphas: List[float]
return results
def convert_loras(part: str):
def convert_diffusion_lora(part: str):
lora_weights = [
f"diffusion-lora-jack/{part}/model.onnx",
f"diffusion-lora-taters/{part}/model.onnx",
@ -90,5 +90,5 @@ def convert_loras(part: str):
if __name__ == "__main__":
convert_loras("unet")
convert_loras("text_encoder")
convert_diffusion_lora("unet")
convert_diffusion_lora("text_encoder")

View File

@ -2,21 +2,22 @@ from os import mkdir, path
from huggingface_hub.file_download import hf_hub_download
from transformers import CLIPTokenizer, CLIPTextModel
from torch.onnx import export
from sys import argv
from logging import getLogger
from ..utils import ConversionContext, sanitize_name
from ..utils import ConversionContext
import torch
logger = getLogger(__name__)
def convert_diffusion_textual_inversion(context: ConversionContext, base_model: str, inversion: str):
cache_path = path.join(context.cache_path, f"inversion-{sanitize_name(inversion)}")
logger.info("converting textual inversion: %s -> %s", inversion, cache_path)
def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
cache_path = path.join(context.cache_path, f"inversion-{name}")
logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, cache_path)
if path.exists(cache_path):
logger.info("ONNX model already exists, skipping.")
if not path.exists(cache_path):
mkdir(cache_path)
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
@ -82,7 +83,3 @@ def convert_diffusion_textual_inversion(context: ConversionContext, base_model:
do_constant_folding=True,
opset_version=context.opset,
)
if __name__ == "__main__":
context = ConversionContext.from_environ()
convert_diffusion_textual_inversion(context, argv[1], argv[2])

View File

@ -10,6 +10,15 @@ $defs:
- type: number
- type: string
textual_inversion:
type: object
required: [name, source]
properties:
name:
type: string
source:
type: string
base_model:
type: object
required: [name, source]
@ -37,6 +46,10 @@ $defs:
properties:
config:
type: string
inversions:
type: array
items:
$ref: "#/$defs/textual_inversion"
vae:
type: string