From 6b4ced26080a29ee48bfc8b82edc22d7fab8c5a3 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 21 Feb 2023 21:40:57 -0600 Subject: [PATCH] load inversions from extras file --- api/onnx_web/convert/__main__.py | 10 +++++++--- api/onnx_web/convert/diffusion/lora.py | 6 +++--- .../convert/diffusion/textual_inversion.py | 19 ++++++++----------- api/schemas/extras.yaml | 13 +++++++++++++ 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index eab146fb..b1869fd2 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -217,9 +217,6 @@ def convert_models(ctx: ConversionContext, args, models: Models): ctx, name, model["source"], model_format=model_format ) - if "inversion" in model: - convert_diffusion_textual_inversion(ctx, source, model["inversion"]) - if model_format in model_formats_original: convert_diffusion_original( ctx, @@ -232,6 +229,13 @@ def convert_models(ctx: ConversionContext, args, models: Models): model, source, ) + + for inversion in model.get("inversions", []): + inversion_name = inversion["name"] + inversion_source = inversion["source"] + inversion_source = fetch_model(ctx, f"{name}-inversion-{inversion_name}", inversion_source) + convert_diffusion_textual_inversion(ctx, inversion_name, source, inversion_source) + except Exception as e: logger.error("error converting diffusion model %s: %s", name, e) diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 1e4d24c2..1b97b430 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -39,7 +39,7 @@ def blend_loras(base: ModelProto, weights: List[ModelProto], alphas: List[float] return results -def convert_loras(part: str): +def convert_diffusion_lora(part: str): lora_weights = [ f"diffusion-lora-jack/{part}/model.onnx", f"diffusion-lora-taters/{part}/model.onnx", @@ -90,5 +90,5 @@ def convert_loras(part: str): if __name__ == "__main__": - convert_loras("unet") - convert_loras("text_encoder") \ No newline at end of file + convert_diffusion_lora("unet") + convert_diffusion_lora("text_encoder") \ No newline at end of file diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index 24f5a1bc..d125f4b2 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -2,22 +2,23 @@ from os import mkdir, path from huggingface_hub.file_download import hf_hub_download from transformers import CLIPTokenizer, CLIPTextModel from torch.onnx import export -from sys import argv from logging import getLogger -from ..utils import ConversionContext, sanitize_name +from ..utils import ConversionContext import torch logger = getLogger(__name__) -def convert_diffusion_textual_inversion(context: ConversionContext, base_model: str, inversion: str): - cache_path = path.join(context.cache_path, f"inversion-{sanitize_name(inversion)}") - logger.info("converting textual inversion: %s -> %s", inversion, cache_path) +def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str): + cache_path = path.join(context.cache_path, f"inversion-{name}") + logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, cache_path) - if not path.exists(cache_path): - mkdir(cache_path) + if path.exists(cache_path): + logger.info("ONNX model already exists, skipping.") + + mkdir(cache_path) embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin") token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt") @@ -82,7 +83,3 @@ def convert_diffusion_textual_inversion(context: ConversionContext, base_model: do_constant_folding=True, opset_version=context.opset, ) - -if __name__ == "__main__": - context = ConversionContext.from_environ() - convert_diffusion_textual_inversion(context, argv[1], argv[2]) diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml index b0a0f912..2cf47123 100644 --- a/api/schemas/extras.yaml +++ b/api/schemas/extras.yaml @@ -10,6 +10,15 @@ $defs: - type: number - type: string + textual_inversion: + type: object + required: [name, source] + properties: + name: + type: string + source: + type: string + base_model: type: object required: [name, source] @@ -37,6 +46,10 @@ $defs: properties: config: type: string + inversions: + type: array + items: + $ref: "#/$defs/textual_inversion" vae: type: string