load inversions from extras file
This commit is contained in:
parent
9dedfc7b28
commit
6b4ced2608
|
@ -217,9 +217,6 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
ctx, name, model["source"], model_format=model_format
|
||||
)
|
||||
|
||||
if "inversion" in model:
|
||||
convert_diffusion_textual_inversion(ctx, source, model["inversion"])
|
||||
|
||||
if model_format in model_formats_original:
|
||||
convert_diffusion_original(
|
||||
ctx,
|
||||
|
@ -232,6 +229,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
model,
|
||||
source,
|
||||
)
|
||||
|
||||
for inversion in model.get("inversions", []):
|
||||
inversion_name = inversion["name"]
|
||||
inversion_source = inversion["source"]
|
||||
inversion_source = fetch_model(ctx, f"{name}-inversion-{inversion_name}", inversion_source)
|
||||
convert_diffusion_textual_inversion(ctx, inversion_name, source, inversion_source)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("error converting diffusion model %s: %s", name, e)
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ def blend_loras(base: ModelProto, weights: List[ModelProto], alphas: List[float]
|
|||
return results
|
||||
|
||||
|
||||
def convert_loras(part: str):
|
||||
def convert_diffusion_lora(part: str):
|
||||
lora_weights = [
|
||||
f"diffusion-lora-jack/{part}/model.onnx",
|
||||
f"diffusion-lora-taters/{part}/model.onnx",
|
||||
|
@ -90,5 +90,5 @@ def convert_loras(part: str):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert_loras("unet")
|
||||
convert_loras("text_encoder")
|
||||
convert_diffusion_lora("unet")
|
||||
convert_diffusion_lora("text_encoder")
|
|
@ -2,21 +2,22 @@ from os import mkdir, path
|
|||
from huggingface_hub.file_download import hf_hub_download
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
from torch.onnx import export
|
||||
from sys import argv
|
||||
from logging import getLogger
|
||||
|
||||
from ..utils import ConversionContext, sanitize_name
|
||||
from ..utils import ConversionContext
|
||||
|
||||
import torch
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def convert_diffusion_textual_inversion(context: ConversionContext, base_model: str, inversion: str):
|
||||
cache_path = path.join(context.cache_path, f"inversion-{sanitize_name(inversion)}")
|
||||
logger.info("converting textual inversion: %s -> %s", inversion, cache_path)
|
||||
def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
|
||||
cache_path = path.join(context.cache_path, f"inversion-{name}")
|
||||
logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, cache_path)
|
||||
|
||||
if path.exists(cache_path):
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
|
||||
if not path.exists(cache_path):
|
||||
mkdir(cache_path)
|
||||
|
||||
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
||||
|
@ -82,7 +83,3 @@ def convert_diffusion_textual_inversion(context: ConversionContext, base_model:
|
|||
do_constant_folding=True,
|
||||
opset_version=context.opset,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
context = ConversionContext.from_environ()
|
||||
convert_diffusion_textual_inversion(context, argv[1], argv[2])
|
||||
|
|
|
@ -10,6 +10,15 @@ $defs:
|
|||
- type: number
|
||||
- type: string
|
||||
|
||||
textual_inversion:
|
||||
type: object
|
||||
required: [name, source]
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
type: string
|
||||
|
||||
base_model:
|
||||
type: object
|
||||
required: [name, source]
|
||||
|
@ -37,6 +46,10 @@ $defs:
|
|||
properties:
|
||||
config:
|
||||
type: string
|
||||
inversions:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/textual_inversion"
|
||||
vae:
|
||||
type: string
|
||||
|
||||
|
|
Loading…
Reference in New Issue