1
0
Fork 0

load inversions from extras file

This commit is contained in:
Sean Sube 2023-02-21 21:40:57 -06:00
parent 9dedfc7b28
commit 6b4ced2608
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 31 additions and 17 deletions

View File

@ -217,9 +217,6 @@ def convert_models(ctx: ConversionContext, args, models: Models):
ctx, name, model["source"], model_format=model_format ctx, name, model["source"], model_format=model_format
) )
if "inversion" in model:
convert_diffusion_textual_inversion(ctx, source, model["inversion"])
if model_format in model_formats_original: if model_format in model_formats_original:
convert_diffusion_original( convert_diffusion_original(
ctx, ctx,
@ -232,6 +229,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
model, model,
source, source,
) )
for inversion in model.get("inversions", []):
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_source = fetch_model(ctx, f"{name}-inversion-{inversion_name}", inversion_source)
convert_diffusion_textual_inversion(ctx, inversion_name, source, inversion_source)
except Exception as e: except Exception as e:
logger.error("error converting diffusion model %s: %s", name, e) logger.error("error converting diffusion model %s: %s", name, e)

View File

@ -39,7 +39,7 @@ def blend_loras(base: ModelProto, weights: List[ModelProto], alphas: List[float]
return results return results
def convert_loras(part: str): def convert_diffusion_lora(part: str):
lora_weights = [ lora_weights = [
f"diffusion-lora-jack/{part}/model.onnx", f"diffusion-lora-jack/{part}/model.onnx",
f"diffusion-lora-taters/{part}/model.onnx", f"diffusion-lora-taters/{part}/model.onnx",
@ -90,5 +90,5 @@ def convert_loras(part: str):
if __name__ == "__main__": if __name__ == "__main__":
convert_loras("unet") convert_diffusion_lora("unet")
convert_loras("text_encoder") convert_diffusion_lora("text_encoder")

View File

@ -2,22 +2,23 @@ from os import mkdir, path
from huggingface_hub.file_download import hf_hub_download from huggingface_hub.file_download import hf_hub_download
from transformers import CLIPTokenizer, CLIPTextModel from transformers import CLIPTokenizer, CLIPTextModel
from torch.onnx import export from torch.onnx import export
from sys import argv
from logging import getLogger from logging import getLogger
from ..utils import ConversionContext, sanitize_name from ..utils import ConversionContext
import torch import torch
logger = getLogger(__name__) logger = getLogger(__name__)
def convert_diffusion_textual_inversion(context: ConversionContext, base_model: str, inversion: str): def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
cache_path = path.join(context.cache_path, f"inversion-{sanitize_name(inversion)}") cache_path = path.join(context.cache_path, f"inversion-{name}")
logger.info("converting textual inversion: %s -> %s", inversion, cache_path) logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, cache_path)
if not path.exists(cache_path): if path.exists(cache_path):
mkdir(cache_path) logger.info("ONNX model already exists, skipping.")
mkdir(cache_path)
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin") embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt") token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
@ -82,7 +83,3 @@ def convert_diffusion_textual_inversion(context: ConversionContext, base_model:
do_constant_folding=True, do_constant_folding=True,
opset_version=context.opset, opset_version=context.opset,
) )
if __name__ == "__main__":
context = ConversionContext.from_environ()
convert_diffusion_textual_inversion(context, argv[1], argv[2])

View File

@ -10,6 +10,15 @@ $defs:
- type: number - type: number
- type: string - type: string
textual_inversion:
type: object
required: [name, source]
properties:
name:
type: string
source:
type: string
base_model: base_model:
type: object type: object
required: [name, source] required: [name, source]
@ -37,6 +46,10 @@ $defs:
properties: properties:
config: config:
type: string type: string
inversions:
type: array
items:
$ref: "#/$defs/textual_inversion"
vae: vae:
type: string type: string