
feat(api): detect Textual Inversion type from keys (#262)

Sean Sube 2023-03-19 20:16:52 -05:00
parent e19e36ae22
commit 0732058aa8
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 55 additions and 39 deletions

View File

@@ -6,6 +6,7 @@ from sys import exit
 from typing import Any, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 
+from huggingface_hub.file_download import hf_hub_download
 from jsonschema import ValidationError, validate
 from onnx import load_model, save_model
 from transformers import CLIPTokenizer
@@ -216,17 +217,24 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 logger.info("skipping network: %s", name)
             else:
                 network_format = source_format(network)
+                network_model = network.get("model", None)
                 network_type = network["type"]
                 source = network["source"]
 
                 try:
-                    dest = fetch_model(
-                        ctx,
-                        name,
-                        source,
-                        dest=path.join(ctx.model_path, network_type),
-                        format=network_format,
-                    )
+                    if network_type == "inversion" and network_model == "concept":
+                        dest = hf_hub_download(
+                            repo_id=source, filename="learned_embeds.bin"
+                        )
+                    else:
+                        dest = fetch_model(
+                            ctx,
+                            name,
+                            source,
+                            dest=path.join(ctx.model_path, network_type),
+                            format=network_format,
+                        )
+
                     logger.info("finished downloading network: %s -> %s", source, dest)
                 except Exception:
                     logger.exception("error fetching network %s", name)
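
Restated outside the diff: network entries with type "inversion" and model "concept" are now downloaded from the Hugging Face Hub by repo id instead of going through fetch_model. A minimal sketch of that branch, using a hypothetical network entry (the name and repo id below are illustrative, not from this commit):

    # minimal sketch of the new download branch; entry values are hypothetical
    from huggingface_hub.file_download import hf_hub_download

    network = {
        "name": "cat-toy",  # hypothetical name
        "source": "sd-concepts-library/cat-toy",  # hypothetical HF repo id
        "type": "inversion",
        "model": "concept",
    }

    if network["type"] == "inversion" and network.get("model", None) == "concept":
        # concept repos publish a single learned_embeds.bin, so the file is
        # downloaded by repo id rather than fetched as a raw URL
        dest = hf_hub_download(repo_id=network["source"], filename="learned_embeds.bin")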

View File

@@ -13,7 +13,6 @@ from onnx.external_data_helper import (
     write_external_data_tensors,
 )
 from onnxruntime import InferenceSession, OrtValue, SessionOptions
-from safetensors.torch import load_file
 
 from ...server.context import ServerContext
 from ..utils import ConversionContext, load_tensor

View File

@@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
 
 import numpy as np
 import torch
-from huggingface_hub.file_download import hf_hub_download
 from onnx import ModelProto, load_model, numpy_helper, save_model
 from transformers import CLIPTokenizer
@@ -28,12 +27,9 @@ def blend_textual_inversions(
     for name, weight, base_token, inversion_format in inversions:
         if base_token is None:
+            logger.debug("no base token provided, using name: %s", name)
            base_token = name
 
-        if inversion_format is None:
-            # TODO: detect concept format
-            inversion_format = "embeddings"
-
         logger.info(
             "blending Textual Inversion %s with weight of %s for token %s",
             name,
@@ -41,23 +37,30 @@ def blend_textual_inversions(
             base_token,
         )
 
-        if inversion_format == "concept":
-            # TODO: this should be done in fetch, maybe
-            embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
-            token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")  # not strictly needed
-
-            with open(token_file, "r") as f:
-                token = f.read()
-
-            loaded_embeds = load_tensor(embeds_file, map_location=device)
-            if loaded_embeds is None:
-                logger.warning("unable to load tensor")
-                continue
+        loaded_embeds = load_tensor(name, map_location=device)
+        if loaded_embeds is None:
+            logger.warning("unable to load tensor")
+            continue
+
+        if inversion_format is None:
+            keys: List[str] = list(loaded_embeds.keys())
+            if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
+                logger.debug("detected Textual Inversion concept: %s", keys)
+                inversion_format = "concept"
+            elif "string_to_token" in keys and "string_to_param" in keys:
+                logger.debug("detected Textual Inversion embeddings: %s", keys)
+                inversion_format = "embeddings"
+            else:
+                logger.error(
+                    "unknown Textual Inversion format, no recognized keys: %s", keys
+                )
+                continue
 
+        if inversion_format == "concept":
             # separate token and the embeds
-            trained_token = list(loaded_embeds.keys())[0]
-            layer = loaded_embeds[trained_token].numpy().astype(dtype)
+            token = list(loaded_embeds.keys())[0]
+            layer = loaded_embeds[token].numpy().astype(dtype)
             layer *= weight
 
             if base_token in embeds:
@@ -70,17 +73,12 @@ def blend_textual_inversions(
             else:
                 embeds[token] = layer
         elif inversion_format == "embeddings":
-            loaded_embeds = load_tensor(name, map_location=device)
-            if loaded_embeds is None:
-                logger.warning("unable to load tensor")
-                continue
-
             string_to_token = loaded_embeds["string_to_token"]
             string_to_param = loaded_embeds["string_to_param"]
 
             # separate token and embeds
-            trained_token = list(string_to_token.keys())[0]
-            trained_embeds = string_to_param[trained_token]
+            token = list(string_to_token.keys())[0]
+            trained_embeds = string_to_param[token]
 
             num_tokens = trained_embeds.shape[0]
             logger.debug("generating %s layer tokens for %s", num_tokens, name)
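
The new key-based check above can be read as a small standalone function. This is a sketch restating the commit's heuristic, not code from the repository:

    from typing import Dict, List, Optional

    def detect_inversion_format(loaded_embeds: Dict) -> Optional[str]:
        # classify a Textual Inversion checkpoint by the keys in its state dict
        keys: List[str] = list(loaded_embeds.keys())
        if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
            return "concept"  # single "<token>" key: concepts-library layout
        if "string_to_token" in keys and "string_to_param" in keys:
            return "embeddings"  # paired lookup tables: embeddings layout
        return None  # unknown layout, the caller skips this network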

View File

@@ -200,9 +200,7 @@ def remove_prefix(name: str, prefix: str) -> str:
 
 def load_torch(name: str, map_location=None) -> Optional[Dict]:
     try:
-        logger.debug(
-            "loading tensor with Torch JIT: %s", name
-        )
+        logger.debug("loading tensor with Torch JIT: %s", name)
         checkpoint = torch.jit.load(name)
     except Exception:
         logger.exception(
@@ -246,7 +244,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
         except Exception as e:
             logger.warning("error loading pickle tensor: %s", e)
     elif extension in ["onnx", "pt"]:
-        logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
+        logger.warning(
+            "tensor has ONNX extension, falling back to PyTorch: %s", extension
+        )
         try:
             checkpoint = load_torch(name, map_location=map_location)
         except Exception as e:

View File

@@ -226,8 +226,7 @@ def load_pipeline(
     if loras is not None and len(loras) > 0:
         lora_names, lora_weights = zip(*loras)
         lora_models = [
-            path.join(server.model_path, "lora", name)
-            for name in lora_names
+            path.join(server.model_path, "lora", name) for name in lora_names
         ]
         logger.info(
             "blending base model %s with LoRA models: %s", model, lora_models

View File

@@ -22,7 +22,9 @@ def worker_main(context: WorkerContext, server: ServerContext):
     apply_patches(server)
     setproctitle("onnx-web worker: %s" % (context.device.device))
 
-    logger.trace("checking in from worker with providers: %s", get_available_providers())
+    logger.trace(
+        "checking in from worker with providers: %s", get_available_providers()
+    )
 
     # make leaking workers easier to recycle
     context.progress.cancel_join_thread()

View File

@@ -108,6 +108,16 @@ $defs:
     format:
       type: string
       enum: [ckpt, safetensors]
+    model:
+      type: string
+      enum: [
+        # inversion
+        concept,
+        embeddings,
+        # lora
+        cloneofsimo,
+        sd-scripts
+      ]
     name:
       type: string
     source:
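
An entry using the new model field can be checked with jsonschema, which the conversion script already imports. A minimal sketch, assuming the network $def above is hand-copied into a standalone Python schema; the entry values are hypothetical:

    from jsonschema import ValidationError, validate

    # hand-copied subset of the network $def shown above
    network_schema = {
        "type": "object",
        "properties": {
            "format": {"type": "string", "enum": ["ckpt", "safetensors"]},
            "model": {
                "type": "string",
                "enum": ["concept", "embeddings", "cloneofsimo", "sd-scripts"],
            },
            "name": {"type": "string"},
            "source": {"type": "string"},
        },
    }

    try:
        validate({"name": "cat-toy", "model": "concept"}, network_schema)
    except ValidationError:
        print("invalid network entry")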