load inversions from extras file
This commit is contained in:
parent
9dedfc7b28
commit
6b4ced2608
|
@ -217,9 +217,6 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
ctx, name, model["source"], model_format=model_format
|
ctx, name, model["source"], model_format=model_format
|
||||||
)
|
)
|
||||||
|
|
||||||
if "inversion" in model:
|
|
||||||
convert_diffusion_textual_inversion(ctx, source, model["inversion"])
|
|
||||||
|
|
||||||
if model_format in model_formats_original:
|
if model_format in model_formats_original:
|
||||||
convert_diffusion_original(
|
convert_diffusion_original(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -232,6 +229,13 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
model,
|
model,
|
||||||
source,
|
source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for inversion in model.get("inversions", []):
|
||||||
|
inversion_name = inversion["name"]
|
||||||
|
inversion_source = inversion["source"]
|
||||||
|
inversion_source = fetch_model(ctx, f"{name}-inversion-{inversion_name}", inversion_source)
|
||||||
|
convert_diffusion_textual_inversion(ctx, inversion_name, source, inversion_source)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("error converting diffusion model %s: %s", name, e)
|
logger.error("error converting diffusion model %s: %s", name, e)
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ def blend_loras(base: ModelProto, weights: List[ModelProto], alphas: List[float]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def convert_loras(part: str):
|
def convert_diffusion_lora(part: str):
|
||||||
lora_weights = [
|
lora_weights = [
|
||||||
f"diffusion-lora-jack/{part}/model.onnx",
|
f"diffusion-lora-jack/{part}/model.onnx",
|
||||||
f"diffusion-lora-taters/{part}/model.onnx",
|
f"diffusion-lora-taters/{part}/model.onnx",
|
||||||
|
@ -90,5 +90,5 @@ def convert_loras(part: str):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
convert_loras("unet")
|
convert_diffusion_lora("unet")
|
||||||
convert_loras("text_encoder")
|
convert_diffusion_lora("text_encoder")
|
|
@ -2,21 +2,22 @@ from os import mkdir, path
|
||||||
from huggingface_hub.file_download import hf_hub_download
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
from sys import argv
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
from ..utils import ConversionContext, sanitize_name
|
from ..utils import ConversionContext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def convert_diffusion_textual_inversion(context: ConversionContext, base_model: str, inversion: str):
|
def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
|
||||||
cache_path = path.join(context.cache_path, f"inversion-{sanitize_name(inversion)}")
|
cache_path = path.join(context.cache_path, f"inversion-{name}")
|
||||||
logger.info("converting textual inversion: %s -> %s", inversion, cache_path)
|
logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, cache_path)
|
||||||
|
|
||||||
|
if path.exists(cache_path):
|
||||||
|
logger.info("ONNX model already exists, skipping.")
|
||||||
|
|
||||||
if not path.exists(cache_path):
|
|
||||||
mkdir(cache_path)
|
mkdir(cache_path)
|
||||||
|
|
||||||
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
||||||
|
@ -82,7 +83,3 @@ def convert_diffusion_textual_inversion(context: ConversionContext, base_model:
|
||||||
do_constant_folding=True,
|
do_constant_folding=True,
|
||||||
opset_version=context.opset,
|
opset_version=context.opset,
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
context = ConversionContext.from_environ()
|
|
||||||
convert_diffusion_textual_inversion(context, argv[1], argv[2])
|
|
||||||
|
|
|
@ -10,6 +10,15 @@ $defs:
|
||||||
- type: number
|
- type: number
|
||||||
- type: string
|
- type: string
|
||||||
|
|
||||||
|
textual_inversion:
|
||||||
|
type: object
|
||||||
|
required: [name, source]
|
||||||
|
properties:
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
source:
|
||||||
|
type: string
|
||||||
|
|
||||||
base_model:
|
base_model:
|
||||||
type: object
|
type: object
|
||||||
required: [name, source]
|
required: [name, source]
|
||||||
|
@ -37,6 +46,10 @@ $defs:
|
||||||
properties:
|
properties:
|
||||||
config:
|
config:
|
||||||
type: string
|
type: string
|
||||||
|
inversions:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/$defs/textual_inversion"
|
||||||
vae:
|
vae:
|
||||||
type: string
|
type: string
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue