feat(api): detect Textual Inversion type from keys (#262)
This commit is contained in:
parent
e19e36ae22
commit
0732058aa8
|
@ -6,6 +6,7 @@ from sys import exit
|
|||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
from jsonschema import ValidationError, validate
|
||||
from onnx import load_model, save_model
|
||||
from transformers import CLIPTokenizer
|
||||
|
@ -216,17 +217,24 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
logger.info("skipping network: %s", name)
|
||||
else:
|
||||
network_format = source_format(network)
|
||||
network_model = network.get("model", None)
|
||||
network_type = network["type"]
|
||||
source = network["source"]
|
||||
|
||||
try:
|
||||
dest = fetch_model(
|
||||
ctx,
|
||||
name,
|
||||
source,
|
||||
dest=path.join(ctx.model_path, network_type),
|
||||
format=network_format,
|
||||
)
|
||||
if network_type == "inversion" and network_model == "concept":
|
||||
dest = hf_hub_download(
|
||||
repo_id=source, filename="learned_embeds.bin"
|
||||
)
|
||||
else:
|
||||
dest = fetch_model(
|
||||
ctx,
|
||||
name,
|
||||
source,
|
||||
dest=path.join(ctx.model_path, network_type),
|
||||
format=network_format,
|
||||
)
|
||||
|
||||
logger.info("finished downloading network: %s -> %s", source, dest)
|
||||
except Exception:
|
||||
logger.exception("error fetching network %s", name)
|
||||
|
|
|
@ -13,7 +13,6 @@ from onnx.external_data_helper import (
|
|||
write_external_data_tensors,
|
||||
)
|
||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from ...server.context import ServerContext
|
||||
from ..utils import ConversionContext, load_tensor
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
from onnx import ModelProto, load_model, numpy_helper, save_model
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
|
@ -28,12 +27,9 @@ def blend_textual_inversions(
|
|||
|
||||
for name, weight, base_token, inversion_format in inversions:
|
||||
if base_token is None:
|
||||
logger.debug("no base token provided, using name: %s", name)
|
||||
base_token = name
|
||||
|
||||
if inversion_format is None:
|
||||
# TODO: detect concept format
|
||||
inversion_format = "embeddings"
|
||||
|
||||
logger.info(
|
||||
"blending Textual Inversion %s with weight of %s for token %s",
|
||||
name,
|
||||
|
@ -41,23 +37,30 @@ def blend_textual_inversions(
|
|||
base_token,
|
||||
)
|
||||
|
||||
if inversion_format == "concept":
|
||||
# TODO: this should be done in fetch, maybe
|
||||
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
|
||||
token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt") # not strictly needed
|
||||
loaded_embeds = load_tensor(name, map_location=device)
|
||||
if loaded_embeds is None:
|
||||
logger.warning("unable to load tensor")
|
||||
continue
|
||||
|
||||
with open(token_file, "r") as f:
|
||||
token = f.read()
|
||||
|
||||
loaded_embeds = load_tensor(embeds_file, map_location=device)
|
||||
if loaded_embeds is None:
|
||||
logger.warning("unable to load tensor")
|
||||
if inversion_format is None:
|
||||
keys: List[str] = list(loaded_embeds.keys())
|
||||
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
|
||||
logger.debug("detected Textual Inversion concept: %s", keys)
|
||||
inversion_format = "concept"
|
||||
elif "string_to_token" in keys and "string_to_param" in keys:
|
||||
logger.debug("detected Textual Inversion embeddings: %s", keys)
|
||||
inversion_format = "embeddings"
|
||||
else:
|
||||
logger.error(
|
||||
"unknown Textual Inversion format, no recognized keys: %s", keys
|
||||
)
|
||||
continue
|
||||
|
||||
if inversion_format == "concept":
|
||||
# separate token and the embeds
|
||||
trained_token = list(loaded_embeds.keys())[0]
|
||||
token = list(loaded_embeds.keys())[0]
|
||||
|
||||
layer = loaded_embeds[trained_token].numpy().astype(dtype)
|
||||
layer = loaded_embeds[token].numpy().astype(dtype)
|
||||
layer *= weight
|
||||
|
||||
if base_token in embeds:
|
||||
|
@ -70,17 +73,12 @@ def blend_textual_inversions(
|
|||
else:
|
||||
embeds[token] = layer
|
||||
elif inversion_format == "embeddings":
|
||||
loaded_embeds = load_tensor(name, map_location=device)
|
||||
if loaded_embeds is None:
|
||||
logger.warning("unable to load tensor")
|
||||
continue
|
||||
|
||||
string_to_token = loaded_embeds["string_to_token"]
|
||||
string_to_param = loaded_embeds["string_to_param"]
|
||||
|
||||
# separate token and embeds
|
||||
trained_token = list(string_to_token.keys())[0]
|
||||
trained_embeds = string_to_param[trained_token]
|
||||
token = list(string_to_token.keys())[0]
|
||||
trained_embeds = string_to_param[token]
|
||||
|
||||
num_tokens = trained_embeds.shape[0]
|
||||
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
||||
|
|
|
@ -200,9 +200,7 @@ def remove_prefix(name: str, prefix: str) -> str:
|
|||
|
||||
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||
try:
|
||||
logger.debug(
|
||||
"loading tensor with Torch JIT: %s", name
|
||||
)
|
||||
logger.debug("loading tensor with Torch JIT: %s", name)
|
||||
checkpoint = torch.jit.load(name)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
|
@ -246,7 +244,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
|||
except Exception as e:
|
||||
logger.warning("error loading pickle tensor: %s", e)
|
||||
elif extension in ["onnx", "pt"]:
|
||||
logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
|
||||
logger.warning(
|
||||
"tensor has ONNX extension, falling back to PyTorch: %s", extension
|
||||
)
|
||||
try:
|
||||
checkpoint = load_torch(name, map_location=map_location)
|
||||
except Exception as e:
|
||||
|
|
|
@ -226,8 +226,7 @@ def load_pipeline(
|
|||
if loras is not None and len(loras) > 0:
|
||||
lora_names, lora_weights = zip(*loras)
|
||||
lora_models = [
|
||||
path.join(server.model_path, "lora", name)
|
||||
for name in lora_names
|
||||
path.join(server.model_path, "lora", name) for name in lora_names
|
||||
]
|
||||
logger.info(
|
||||
"blending base model %s with LoRA models: %s", model, lora_models
|
||||
|
|
|
@ -22,7 +22,9 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
|||
apply_patches(server)
|
||||
setproctitle("onnx-web worker: %s" % (context.device.device))
|
||||
|
||||
logger.trace("checking in from worker with providers: %s", get_available_providers())
|
||||
logger.trace(
|
||||
"checking in from worker with providers: %s", get_available_providers()
|
||||
)
|
||||
|
||||
# make leaking workers easier to recycle
|
||||
context.progress.cancel_join_thread()
|
||||
|
|
|
@ -108,6 +108,16 @@ $defs:
|
|||
format:
|
||||
type: string
|
||||
enum: [ckpt, safetensors]
|
||||
model:
|
||||
type: string
|
||||
enum: [
|
||||
# inversion
|
||||
concept,
|
||||
embeddings,
|
||||
# lora
|
||||
cloneofsimo,
|
||||
sd-scripts
|
||||
]
|
||||
name:
|
||||
type: string
|
||||
source:
|
||||
|
|
Loading…
Reference in New Issue