
feat(api): detect Textual Inversion type from keys (#262)

Sean Sube 2023-03-19 20:16:52 -05:00
parent e19e36ae22
commit 0732058aa8
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 55 additions and 39 deletions
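
The change itself is small: instead of requiring the user to declare whether a Textual Inversion checkpoint is a `concept` or an `embeddings` file, the converter now inspects the tensor's keys. A concept checkpoint holds a single `<token>`-style key, while an embeddings checkpoint carries `string_to_token` and `string_to_param` lookup tables. A minimal standalone sketch of that heuristic (the helper name `detect_inversion_format` is hypothetical, not part of this commit):

    # Sketch of the key-based detection this commit adds; the helper name
    # detect_inversion_format is hypothetical and not in the codebase.
    from typing import Dict, List, Optional

    def detect_inversion_format(loaded_embeds: Dict) -> Optional[str]:
        keys: List[str] = list(loaded_embeds.keys())
        if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
            # concept checkpoints store one entry named after the "<token>"
            return "concept"
        if "string_to_token" in keys and "string_to_param" in keys:
            # embeddings checkpoints store token and parameter lookup tables
            return "embeddings"
        # unknown layout: let the caller log an error and skip it
        return None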

View File

@@ -6,6 +6,7 @@ from sys import exit
 from typing import Any, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 
+from huggingface_hub.file_download import hf_hub_download
 from jsonschema import ValidationError, validate
 from onnx import load_model, save_model
 from transformers import CLIPTokenizer
@@ -216,10 +217,16 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 logger.info("skipping network: %s", name)
             else:
                 network_format = source_format(network)
+                network_model = network.get("model", None)
                 network_type = network["type"]
                 source = network["source"]
 
                 try:
+                    if network_type == "inversion" and network_model == "concept":
+                        dest = hf_hub_download(
+                            repo_id=source, filename="learned_embeds.bin"
+                        )
+                    else:
                     dest = fetch_model(
                         ctx,
                         name,
@@ -227,6 +234,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                         dest=path.join(ctx.model_path, network_type),
                         format=network_format,
                     )
+
                     logger.info("finished downloading network: %s -> %s", source, dest)
                 except Exception:
                     logger.exception("error fetching network %s", name)
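
For concept-type inversions, the converter now bypasses the generic `fetch_model` path and pulls `learned_embeds.bin` straight from the Hugging Face Hub, treating the network's `source` as a repo id. The same call in isolation, with a made-up repo id for illustration:

    from huggingface_hub.file_download import hf_hub_download

    # the repo id below is a hypothetical example, not one used by the project
    dest = hf_hub_download(
        repo_id="sd-concepts-library/example-concept",
        filename="learned_embeds.bin",
    )
    print(dest)  # local cache path of the downloaded embeddings file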

View File

@@ -13,7 +13,6 @@ from onnx.external_data_helper import (
     write_external_data_tensors,
 )
 from onnxruntime import InferenceSession, OrtValue, SessionOptions
-from safetensors.torch import load_file
 
 from ...server.context import ServerContext
 from ..utils import ConversionContext, load_tensor

View File

@@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
-from huggingface_hub.file_download import hf_hub_download
 from onnx import ModelProto, load_model, numpy_helper, save_model
 from transformers import CLIPTokenizer
@@ -28,12 +27,9 @@ def blend_textual_inversions(
     for name, weight, base_token, inversion_format in inversions:
         if base_token is None:
             logger.debug("no base token provided, using name: %s", name)
             base_token = name
 
-        if inversion_format is None:
-            # TODO: detect concept format
-            inversion_format = "embeddings"
 
         logger.info(
             "blending Textual Inversion %s with weight of %s for token %s",
             name,
@@ -41,23 +37,30 @@
             base_token,
         )
 
-        if inversion_format == "concept":
-            # TODO: this should be done in fetch, maybe
-            embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
-            token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")  # not strictly needed
-
-            with open(token_file, "r") as f:
-                token = f.read()
-
-            loaded_embeds = load_tensor(embeds_file, map_location=device)
+        loaded_embeds = load_tensor(name, map_location=device)
         if loaded_embeds is None:
             logger.warning("unable to load tensor")
             continue
 
-            # separate token and the embeds
-            trained_token = list(loaded_embeds.keys())[0]
+        if inversion_format is None:
+            keys: List[str] = list(loaded_embeds.keys())
+            if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
+                logger.debug("detected Textual Inversion concept: %s", keys)
+                inversion_format = "concept"
+            elif "string_to_token" in keys and "string_to_param" in keys:
+                logger.debug("detected Textual Inversion embeddings: %s", keys)
+                inversion_format = "embeddings"
+            else:
+                logger.error(
+                    "unknown Textual Inversion format, no recognized keys: %s", keys
+                )
+                continue
 
-            layer = loaded_embeds[trained_token].numpy().astype(dtype)
+        if inversion_format == "concept":
+            # separate token and the embeds
+            token = list(loaded_embeds.keys())[0]
+
+            layer = loaded_embeds[token].numpy().astype(dtype)
             layer *= weight
 
             if base_token in embeds:
@@ -70,17 +73,12 @@
             else:
                 embeds[token] = layer
         elif inversion_format == "embeddings":
-            loaded_embeds = load_tensor(name, map_location=device)
-            if loaded_embeds is None:
-                logger.warning("unable to load tensor")
-                continue
-
             string_to_token = loaded_embeds["string_to_token"]
             string_to_param = loaded_embeds["string_to_param"]
 
             # separate token and embeds
-            trained_token = list(string_to_token.keys())[0]
-            trained_embeds = string_to_param[trained_token]
+            token = list(string_to_token.keys())[0]
+            trained_embeds = string_to_param[token]
 
             num_tokens = trained_embeds.shape[0]
             logger.debug("generating %s layer tokens for %s", num_tokens, name)
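
However the format is detected, the blend itself is a weighted accumulation of embedding layers keyed by token: each layer is scaled by its weight and combined with whatever is already stored for that token. A small numpy sketch of that accumulation, assuming the branch elided above sums into the existing layer (768 is the CLIP hidden size used by SD 1.x; the token name is illustrative):

    import numpy as np

    embeds = {}  # token -> accumulated embedding layer

    def blend_layer(token: str, layer: np.ndarray, weight: float) -> None:
        # scale the incoming layer, then add it to any existing layer
        scaled = layer * weight
        if token in embeds:
            embeds[token] = embeds[token] + scaled
        else:
            embeds[token] = scaled

    # two inversions blended onto the same token at different weights
    blend_layer("goblin", np.ones((768,), dtype=np.float32), 0.75)
    blend_layer("goblin", np.full((768,), 0.5, dtype=np.float32), 0.25)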

View File

@@ -200,9 +200,7 @@ def remove_prefix(name: str, prefix: str) -> str:
 
 def load_torch(name: str, map_location=None) -> Optional[Dict]:
     try:
-        logger.debug(
-            "loading tensor with Torch JIT: %s", name
-        )
+        logger.debug("loading tensor with Torch JIT: %s", name)
         checkpoint = torch.jit.load(name)
     except Exception:
         logger.exception(
@@ -246,7 +244,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
         except Exception as e:
             logger.warning("error loading pickle tensor: %s", e)
     elif extension in ["onnx", "pt"]:
-        logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
+        logger.warning(
+            "tensor has ONNX extension, falling back to PyTorch: %s", extension
+        )
         try:
             checkpoint = load_torch(name, map_location=map_location)
         except Exception as e:
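
`load_tensor` picks a loader from the file extension, and as the hunk above shows, checkpoints with unexpected extensions (`onnx`, `pt`) now fall back to the PyTorch loader with a warning. A condensed sketch of that dispatch pattern, not the project's exact implementation:

    from os import path
    from typing import Dict, Optional

    import torch
    from safetensors.torch import load_file

    def load_checkpoint(name: str, map_location=None) -> Optional[Dict]:
        _, extension = path.splitext(name)
        extension = extension.lstrip(".").lower()

        if extension == "safetensors":
            return load_file(name)

        # pickle tensors and odd extensions fall back to torch.load
        try:
            return torch.load(name, map_location=map_location)
        except Exception:
            return None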

View File

@@ -226,8 +226,7 @@ def load_pipeline(
     if loras is not None and len(loras) > 0:
         lora_names, lora_weights = zip(*loras)
         lora_models = [
-            path.join(server.model_path, "lora", name)
-            for name in lora_names
+            path.join(server.model_path, "lora", name) for name in lora_names
         ]
         logger.info(
             "blending base model %s with LoRA models: %s", model, lora_models

View File

@@ -22,7 +22,9 @@ def worker_main(context: WorkerContext, server: ServerContext):
     apply_patches(server)
     setproctitle("onnx-web worker: %s" % (context.device.device))
 
-    logger.trace("checking in from worker with providers: %s", get_available_providers())
+    logger.trace(
+        "checking in from worker with providers: %s", get_available_providers()
+    )
 
     # make leaking workers easier to recycle
     context.progress.cancel_join_thread()

View File

@@ -108,6 +108,16 @@ $defs:
     format:
       type: string
       enum: [ckpt, safetensors]
+    model:
+      type: string
+      enum: [
+        # inversion
+        concept,
+        embeddings,
+        # lora
+        cloneofsimo,
+        sd-scripts
+      ]
     name:
       type: string
     source:
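
With the new `model` field, a network entry can declare its checkpoint layout up front, with key-based detection as the fallback. One way to exercise a fragment like this, using the `jsonschema` package the converter already imports (the inline schema and entry below are illustrative, not the project's full model schema):

    from jsonschema import validate

    # illustrative fragment of the schema above, not the full document
    network_schema = {
        "type": "object",
        "properties": {
            "format": {"type": "string", "enum": ["ckpt", "safetensors"]},
            "model": {
                "type": "string",
                "enum": ["concept", "embeddings", "cloneofsimo", "sd-scripts"],
            },
            "name": {"type": "string"},
            "source": {"type": "string"},
        },
    }

    # hypothetical inversion network sourced from a Hugging Face concept repo
    network = {
        "name": "example-concept",
        "source": "sd-concepts-library/example-concept",
        "type": "inversion",
        "model": "concept",
    }

    validate(network, network_schema)  # raises ValidationError on mismatch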