From 0732058aa89faeedac192e23989006319336d9d6 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Mar 2023 20:16:52 -0500 Subject: [PATCH] feat(api): detect Textual Inversion type from keys (#262) --- api/onnx_web/convert/__main__.py | 22 ++++++--- api/onnx_web/convert/diffusion/lora.py | 1 - .../convert/diffusion/textual_inversion.py | 46 +++++++++---------- api/onnx_web/convert/utils.py | 8 ++-- api/onnx_web/diffusers/load.py | 3 +- api/onnx_web/worker/worker.py | 4 +- api/schemas/extras.yaml | 10 ++++ 7 files changed, 55 insertions(+), 39 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 0b81b20d..5ec93076 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -6,6 +6,7 @@ from sys import exit from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse +from huggingface_hub.file_download import hf_hub_download from jsonschema import ValidationError, validate from onnx import load_model, save_model from transformers import CLIPTokenizer @@ -216,17 +217,24 @@ def convert_models(ctx: ConversionContext, args, models: Models): logger.info("skipping network: %s", name) else: network_format = source_format(network) + network_model = network.get("model", None) network_type = network["type"] source = network["source"] try: - dest = fetch_model( - ctx, - name, - source, - dest=path.join(ctx.model_path, network_type), - format=network_format, - ) + if network_type == "inversion" and network_model == "concept": + dest = hf_hub_download( + repo_id=source, filename="learned_embeds.bin" + ) + else: + dest = fetch_model( + ctx, + name, + source, + dest=path.join(ctx.model_path, network_type), + format=network_format, + ) + logger.info("finished downloading network: %s -> %s", source, dest) except Exception: logger.exception("error fetching network %s", name) diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index d75b1862..4a8cd3db 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -13,7 +13,6 @@ from onnx.external_data_helper import ( write_external_data_tensors, ) from onnxruntime import InferenceSession, OrtValue, SessionOptions -from safetensors.torch import load_file from ...server.context import ServerContext from ..utils import ConversionContext, load_tensor diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index da0179d4..3650ee55 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -4,7 +4,6 @@ from typing import List, Optional, Tuple import numpy as np import torch -from huggingface_hub.file_download import hf_hub_download from onnx import ModelProto, load_model, numpy_helper, save_model from transformers import CLIPTokenizer @@ -28,12 +27,9 @@ def blend_textual_inversions( for name, weight, base_token, inversion_format in inversions: if base_token is None: + logger.debug("no base token provided, using name: %s", name) base_token = name - if inversion_format is None: - # TODO: detect concept format - inversion_format = "embeddings" - logger.info( "blending Textual Inversion %s with weight of %s for token %s", name, @@ -41,23 +37,30 @@ def blend_textual_inversions( base_token, ) - if inversion_format == "concept": - # TODO: this should be done in fetch, maybe - embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin") - token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt") # not strictly needed + loaded_embeds = load_tensor(name, map_location=device) + if loaded_embeds is None: + logger.warning("unable to load tensor") + continue - with open(token_file, "r") as f: - token = f.read() - - loaded_embeds = load_tensor(embeds_file, map_location=device) - if loaded_embeds is None: - logger.warning("unable to load tensor") + if inversion_format is None: + keys: List[str] = list(loaded_embeds.keys()) + if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"): + logger.debug("detected Textual Inversion concept: %s", keys) + inversion_format = "concept" + elif "string_to_token" in keys and "string_to_param" in keys: + logger.debug("detected Textual Inversion embeddings: %s", keys) + inversion_format = "embeddings" + else: + logger.error( + "unknown Textual Inversion format, no recognized keys: %s", keys + ) continue + if inversion_format == "concept": # separate token and the embeds - trained_token = list(loaded_embeds.keys())[0] + token = list(loaded_embeds.keys())[0] - layer = loaded_embeds[trained_token].numpy().astype(dtype) + layer = loaded_embeds[token].numpy().astype(dtype) layer *= weight if base_token in embeds: @@ -70,17 +73,12 @@ def blend_textual_inversions( else: embeds[token] = layer elif inversion_format == "embeddings": - loaded_embeds = load_tensor(name, map_location=device) - if loaded_embeds is None: - logger.warning("unable to load tensor") - continue - string_to_token = loaded_embeds["string_to_token"] string_to_param = loaded_embeds["string_to_param"] # separate token and embeds - trained_token = list(string_to_token.keys())[0] - trained_embeds = string_to_param[trained_token] + token = list(string_to_token.keys())[0] + trained_embeds = string_to_param[token] num_tokens = trained_embeds.shape[0] logger.debug("generating %s layer tokens for %s", num_tokens, name) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 6e482d63..290aa513 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -200,9 +200,7 @@ def remove_prefix(name: str, prefix: str) -> str: def load_torch(name: str, map_location=None) -> Optional[Dict]: try: - logger.debug( - "loading tensor with Torch JIT: %s", name - ) + logger.debug("loading tensor with Torch JIT: %s", name) checkpoint = torch.jit.load(name) except Exception: logger.exception( @@ -246,7 +244,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]: except Exception as e: logger.warning("error loading pickle tensor: %s", e) elif extension in ["onnx", "pt"]: - logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension) + logger.warning( + "tensor has ONNX extension, falling back to PyTorch: %s", extension + ) try: checkpoint = load_torch(name, map_location=map_location) except Exception as e: diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 474d300a..08f3cbce 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -226,8 +226,7 @@ def load_pipeline( if loras is not None and len(loras) > 0: lora_names, lora_weights = zip(*loras) lora_models = [ - path.join(server.model_path, "lora", name) - for name in lora_names + path.join(server.model_path, "lora", name) for name in lora_names ] logger.info( "blending base model %s with LoRA models: %s", model, lora_models diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index e7bb2671..e30c52c3 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -22,7 +22,9 @@ def worker_main(context: WorkerContext, server: ServerContext): apply_patches(server) setproctitle("onnx-web worker: %s" % (context.device.device)) - logger.trace("checking in from worker with providers: %s", get_available_providers()) + logger.trace( + "checking in from worker with providers: %s", get_available_providers() + ) # make leaking workers easier to recycle context.progress.cancel_join_thread() diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml index c1d39b6e..0ad7a3d5 100644 --- a/api/schemas/extras.yaml +++ b/api/schemas/extras.yaml @@ -108,6 +108,16 @@ $defs: format: type: string enum: [ckpt, safetensors] + model: + type: string + enum: [ + # inversion + concept, + embeddings, + # lora + cloneofsimo, + sd-scripts + ] name: type: string source: