feat(api): detect Textual Inversion type from keys (#262)
This commit is contained in:
parent
e19e36ae22
commit
0732058aa8
|
@ -6,6 +6,7 @@ from sys import exit
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
from jsonschema import ValidationError, validate
|
from jsonschema import ValidationError, validate
|
||||||
from onnx import load_model, save_model
|
from onnx import load_model, save_model
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
@ -216,17 +217,24 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
logger.info("skipping network: %s", name)
|
logger.info("skipping network: %s", name)
|
||||||
else:
|
else:
|
||||||
network_format = source_format(network)
|
network_format = source_format(network)
|
||||||
|
network_model = network.get("model", None)
|
||||||
network_type = network["type"]
|
network_type = network["type"]
|
||||||
source = network["source"]
|
source = network["source"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dest = fetch_model(
|
if network_type == "inversion" and network_model == "concept":
|
||||||
ctx,
|
dest = hf_hub_download(
|
||||||
name,
|
repo_id=source, filename="learned_embeds.bin"
|
||||||
source,
|
)
|
||||||
dest=path.join(ctx.model_path, network_type),
|
else:
|
||||||
format=network_format,
|
dest = fetch_model(
|
||||||
)
|
ctx,
|
||||||
|
name,
|
||||||
|
source,
|
||||||
|
dest=path.join(ctx.model_path, network_type),
|
||||||
|
format=network_format,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("finished downloading network: %s -> %s", source, dest)
|
logger.info("finished downloading network: %s -> %s", source, dest)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error fetching network %s", name)
|
logger.exception("error fetching network %s", name)
|
||||||
|
|
|
@ -13,7 +13,6 @@ from onnx.external_data_helper import (
|
||||||
write_external_data_tensors,
|
write_external_data_tensors,
|
||||||
)
|
)
|
||||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||||
from safetensors.torch import load_file
|
|
||||||
|
|
||||||
from ...server.context import ServerContext
|
from ...server.context import ServerContext
|
||||||
from ..utils import ConversionContext, load_tensor
|
from ..utils import ConversionContext, load_tensor
|
||||||
|
|
|
@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub.file_download import hf_hub_download
|
|
||||||
from onnx import ModelProto, load_model, numpy_helper, save_model
|
from onnx import ModelProto, load_model, numpy_helper, save_model
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
|
@ -28,12 +27,9 @@ def blend_textual_inversions(
|
||||||
|
|
||||||
for name, weight, base_token, inversion_format in inversions:
|
for name, weight, base_token, inversion_format in inversions:
|
||||||
if base_token is None:
|
if base_token is None:
|
||||||
|
logger.debug("no base token provided, using name: %s", name)
|
||||||
base_token = name
|
base_token = name
|
||||||
|
|
||||||
if inversion_format is None:
|
|
||||||
# TODO: detect concept format
|
|
||||||
inversion_format = "embeddings"
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"blending Textual Inversion %s with weight of %s for token %s",
|
"blending Textual Inversion %s with weight of %s for token %s",
|
||||||
name,
|
name,
|
||||||
|
@ -41,23 +37,30 @@ def blend_textual_inversions(
|
||||||
base_token,
|
base_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if inversion_format == "concept":
|
loaded_embeds = load_tensor(name, map_location=device)
|
||||||
# TODO: this should be done in fetch, maybe
|
if loaded_embeds is None:
|
||||||
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
|
logger.warning("unable to load tensor")
|
||||||
token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt") # not strictly needed
|
continue
|
||||||
|
|
||||||
with open(token_file, "r") as f:
|
if inversion_format is None:
|
||||||
token = f.read()
|
keys: List[str] = list(loaded_embeds.keys())
|
||||||
|
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
|
||||||
loaded_embeds = load_tensor(embeds_file, map_location=device)
|
logger.debug("detected Textual Inversion concept: %s", keys)
|
||||||
if loaded_embeds is None:
|
inversion_format = "concept"
|
||||||
logger.warning("unable to load tensor")
|
elif "string_to_token" in keys and "string_to_param" in keys:
|
||||||
|
logger.debug("detected Textual Inversion embeddings: %s", keys)
|
||||||
|
inversion_format = "embeddings"
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"unknown Textual Inversion format, no recognized keys: %s", keys
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if inversion_format == "concept":
|
||||||
# separate token and the embeds
|
# separate token and the embeds
|
||||||
trained_token = list(loaded_embeds.keys())[0]
|
token = list(loaded_embeds.keys())[0]
|
||||||
|
|
||||||
layer = loaded_embeds[trained_token].numpy().astype(dtype)
|
layer = loaded_embeds[token].numpy().astype(dtype)
|
||||||
layer *= weight
|
layer *= weight
|
||||||
|
|
||||||
if base_token in embeds:
|
if base_token in embeds:
|
||||||
|
@ -70,17 +73,12 @@ def blend_textual_inversions(
|
||||||
else:
|
else:
|
||||||
embeds[token] = layer
|
embeds[token] = layer
|
||||||
elif inversion_format == "embeddings":
|
elif inversion_format == "embeddings":
|
||||||
loaded_embeds = load_tensor(name, map_location=device)
|
|
||||||
if loaded_embeds is None:
|
|
||||||
logger.warning("unable to load tensor")
|
|
||||||
continue
|
|
||||||
|
|
||||||
string_to_token = loaded_embeds["string_to_token"]
|
string_to_token = loaded_embeds["string_to_token"]
|
||||||
string_to_param = loaded_embeds["string_to_param"]
|
string_to_param = loaded_embeds["string_to_param"]
|
||||||
|
|
||||||
# separate token and embeds
|
# separate token and embeds
|
||||||
trained_token = list(string_to_token.keys())[0]
|
token = list(string_to_token.keys())[0]
|
||||||
trained_embeds = string_to_param[trained_token]
|
trained_embeds = string_to_param[token]
|
||||||
|
|
||||||
num_tokens = trained_embeds.shape[0]
|
num_tokens = trained_embeds.shape[0]
|
||||||
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
||||||
|
|
|
@ -200,9 +200,7 @@ def remove_prefix(name: str, prefix: str) -> str:
|
||||||
|
|
||||||
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
logger.debug("loading tensor with Torch JIT: %s", name)
|
||||||
"loading tensor with Torch JIT: %s", name
|
|
||||||
)
|
|
||||||
checkpoint = torch.jit.load(name)
|
checkpoint = torch.jit.load(name)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
|
@ -246,7 +244,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("error loading pickle tensor: %s", e)
|
logger.warning("error loading pickle tensor: %s", e)
|
||||||
elif extension in ["onnx", "pt"]:
|
elif extension in ["onnx", "pt"]:
|
||||||
logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
|
logger.warning(
|
||||||
|
"tensor has ONNX extension, falling back to PyTorch: %s", extension
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
checkpoint = load_torch(name, map_location=map_location)
|
checkpoint = load_torch(name, map_location=map_location)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -226,8 +226,7 @@ def load_pipeline(
|
||||||
if loras is not None and len(loras) > 0:
|
if loras is not None and len(loras) > 0:
|
||||||
lora_names, lora_weights = zip(*loras)
|
lora_names, lora_weights = zip(*loras)
|
||||||
lora_models = [
|
lora_models = [
|
||||||
path.join(server.model_path, "lora", name)
|
path.join(server.model_path, "lora", name) for name in lora_names
|
||||||
for name in lora_names
|
|
||||||
]
|
]
|
||||||
logger.info(
|
logger.info(
|
||||||
"blending base model %s with LoRA models: %s", model, lora_models
|
"blending base model %s with LoRA models: %s", model, lora_models
|
||||||
|
|
|
@ -22,7 +22,9 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
||||||
apply_patches(server)
|
apply_patches(server)
|
||||||
setproctitle("onnx-web worker: %s" % (context.device.device))
|
setproctitle("onnx-web worker: %s" % (context.device.device))
|
||||||
|
|
||||||
logger.trace("checking in from worker with providers: %s", get_available_providers())
|
logger.trace(
|
||||||
|
"checking in from worker with providers: %s", get_available_providers()
|
||||||
|
)
|
||||||
|
|
||||||
# make leaking workers easier to recycle
|
# make leaking workers easier to recycle
|
||||||
context.progress.cancel_join_thread()
|
context.progress.cancel_join_thread()
|
||||||
|
|
|
@ -108,6 +108,16 @@ $defs:
|
||||||
format:
|
format:
|
||||||
type: string
|
type: string
|
||||||
enum: [ckpt, safetensors]
|
enum: [ckpt, safetensors]
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
enum: [
|
||||||
|
# inversion
|
||||||
|
concept,
|
||||||
|
embeddings,
|
||||||
|
# lora
|
||||||
|
cloneofsimo,
|
||||||
|
sd-scripts
|
||||||
|
]
|
||||||
name:
|
name:
|
||||||
type: string
|
type: string
|
||||||
source:
|
source:
|
||||||
|
|
Loading…
Reference in New Issue