feat(api): convert Textual Inversion weights
This commit is contained in:
parent
947a1bfdec
commit
a31f7b9e1f
|
@ -1,5 +1,10 @@
|
||||||
{
|
{
|
||||||
"diffusion": [
|
"diffusion": [
|
||||||
|
{
|
||||||
|
"name": "diffusion-ugly-sonic",
|
||||||
|
"source": "runwayml/stable-diffusion-v1-5",
|
||||||
|
"inversion": "sd-concepts-library/ugly-sonic"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "diffusion-knollingcase",
|
"name": "diffusion-knollingcase",
|
||||||
"source": "Aybeeceedee/knollingcase"
|
"source": "Aybeeceedee/knollingcase"
|
||||||
|
|
|
@ -10,8 +10,9 @@ from jsonschema import ValidationError, validate
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
from .correction_gfpgan import convert_correction_gfpgan
|
from .correction_gfpgan import convert_correction_gfpgan
|
||||||
from .diffusion_original import convert_diffusion_original
|
from .diffusion.original import convert_diffusion_original
|
||||||
from .diffusion_stable import convert_diffusion_stable
|
from .diffusion.diffusers import convert_diffusion_diffusers
|
||||||
|
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
|
||||||
from .upscale_resrgan import convert_upscale_resrgan
|
from .upscale_resrgan import convert_upscale_resrgan
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ConversionContext,
|
ConversionContext,
|
||||||
|
@ -216,6 +217,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
ctx, name, model["source"], model_format=model_format
|
ctx, name, model["source"], model_format=model_format
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "inversion" in model:
|
||||||
|
convert_diffusion_textual_inversion(ctx, source, model["inversion"])
|
||||||
|
|
||||||
if model_format in model_formats_original:
|
if model_format in model_formats_original:
|
||||||
convert_diffusion_original(
|
convert_diffusion_original(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -223,7 +227,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
source,
|
source,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
convert_diffusion_stable(
|
convert_diffusion_diffusers(
|
||||||
ctx,
|
ctx,
|
||||||
model,
|
model,
|
||||||
source,
|
source,
|
||||||
|
|
|
@ -25,12 +25,11 @@ from diffusers import (
|
||||||
from onnx import load, save_model
|
from onnx import load, save_model
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
from onnx_web.diffusion.load import optimize_pipeline
|
from ...diffusion.load import optimize_pipeline
|
||||||
|
from ...diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
|
||||||
OnnxStableDiffusionUpscalePipeline,
|
OnnxStableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
from .utils import ConversionContext
|
from ..utils import ConversionContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -63,7 +62,7 @@ def onnx_export(
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_stable(
|
def convert_diffusion_diffusers(
|
||||||
ctx: ConversionContext,
|
ctx: ConversionContext,
|
||||||
model: Dict,
|
model: Dict,
|
||||||
source: str,
|
source: str,
|
|
@ -53,8 +53,8 @@ from transformers import (
|
||||||
CLIPVisionConfig,
|
CLIPVisionConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .diffusion_stable import convert_diffusion_stable
|
from .diffusers import convert_diffusion_diffusers
|
||||||
from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
|
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -1428,5 +1428,5 @@ def convert_diffusion_original(
|
||||||
if "vae" in model:
|
if "vae" in model:
|
||||||
del model["vae"]
|
del model["vae"]
|
||||||
|
|
||||||
convert_diffusion_stable(ctx, model, working_name)
|
convert_diffusion_diffusers(ctx, model, working_name)
|
||||||
logger.info("ONNX pipeline saved to %s", name)
|
logger.info("ONNX pipeline saved to %s", name)
|
|
@ -0,0 +1,88 @@
|
||||||
|
from os import mkdir, path
|
||||||
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
|
from torch.onnx import export
|
||||||
|
from sys import argv
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from ..utils import ConversionContext, sanitize_name
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_diffusion_textual_inversion(context: ConversionContext, base_model: str, inversion: str):
|
||||||
|
cache_path = path.join(context.cache_path, f"inversion-{sanitize_name(inversion)}")
|
||||||
|
logger.info("converting textual inversion: %s -> %s", inversion, cache_path)
|
||||||
|
|
||||||
|
if not path.exists(cache_path):
|
||||||
|
mkdir(cache_path)
|
||||||
|
|
||||||
|
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
||||||
|
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
||||||
|
|
||||||
|
with open(token_file, "r") as f:
|
||||||
|
token = f.read()
|
||||||
|
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
subfolder="tokenizer",
|
||||||
|
)
|
||||||
|
text_encoder = CLIPTextModel.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
subfolder="text_encoder",
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
||||||
|
|
||||||
|
# separate token and the embeds
|
||||||
|
trained_token = list(loaded_embeds.keys())[0]
|
||||||
|
embeds = loaded_embeds[trained_token]
|
||||||
|
|
||||||
|
# cast to dtype of text_encoder
|
||||||
|
dtype = text_encoder.get_input_embeddings().weight.dtype
|
||||||
|
embeds.to(dtype)
|
||||||
|
|
||||||
|
# add the token in tokenizer
|
||||||
|
num_added_tokens = tokenizer.add_tokens(token)
|
||||||
|
if num_added_tokens == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
|
||||||
|
)
|
||||||
|
|
||||||
|
# resize the token embeddings
|
||||||
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
# get the id for the token and assign the embeds
|
||||||
|
token_id = tokenizer.convert_tokens_to_ids(token)
|
||||||
|
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
||||||
|
|
||||||
|
# conversion stuff
|
||||||
|
text_input = tokenizer(
|
||||||
|
"A sample prompt",
|
||||||
|
padding="max_length",
|
||||||
|
max_length=tokenizer.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
export(
|
||||||
|
text_encoder,
|
||||||
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
|
(
|
||||||
|
text_input.input_ids.to(device=context.training_device, dtype=torch.int32)
|
||||||
|
),
|
||||||
|
f=path.join(cache_path, "text_encoder", "model.onnx"),
|
||||||
|
input_names=["input_ids"],
|
||||||
|
output_names=["last_hidden_state", "pooler_output"],
|
||||||
|
dynamic_axes={
|
||||||
|
"input_ids": {0: "batch", 1: "sequence"},
|
||||||
|
},
|
||||||
|
do_constant_folding=True,
|
||||||
|
opset_version=context.opset,
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
context = ConversionContext.from_environ()
|
||||||
|
convert_diffusion_textual_inversion(context, argv[1], argv[2])
|
|
@ -122,8 +122,9 @@ chain_stages = {
|
||||||
available_platforms: List[DeviceParams] = []
|
available_platforms: List[DeviceParams] = []
|
||||||
|
|
||||||
# loaded from model_path
|
# loaded from model_path
|
||||||
diffusion_models: List[str] = []
|
|
||||||
correction_models: List[str] = []
|
correction_models: List[str] = []
|
||||||
|
diffusion_models: List[str] = []
|
||||||
|
inversion_models: List[str] = []
|
||||||
upscaling_models: List[str] = []
|
upscaling_models: List[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
@ -301,8 +302,9 @@ def get_model_name(model: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
def load_models(context: ServerContext) -> None:
|
def load_models(context: ServerContext) -> None:
|
||||||
global diffusion_models
|
|
||||||
global correction_models
|
global correction_models
|
||||||
|
global diffusion_models
|
||||||
|
global inversion_models
|
||||||
global upscaling_models
|
global upscaling_models
|
||||||
|
|
||||||
diffusion_models = [
|
diffusion_models = [
|
||||||
|
@ -323,6 +325,12 @@ def load_models(context: ServerContext) -> None:
|
||||||
correction_models = list(set(correction_models))
|
correction_models = list(set(correction_models))
|
||||||
correction_models.sort()
|
correction_models.sort()
|
||||||
|
|
||||||
|
inversion_models = [
|
||||||
|
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
|
||||||
|
]
|
||||||
|
inversion_models = list(set(inversion_models))
|
||||||
|
inversion_models.sort()
|
||||||
|
|
||||||
upscaling_models = [
|
upscaling_models = [
|
||||||
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
|
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
|
||||||
]
|
]
|
||||||
|
@ -496,8 +504,9 @@ def list_mask_filters():
|
||||||
def list_models():
|
def list_models():
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"diffusion": diffusion_models,
|
|
||||||
"correction": correction_models,
|
"correction": correction_models,
|
||||||
|
"diffusion": diffusion_models,
|
||||||
|
"inversion": inversion_models,
|
||||||
"upscaling": upscaling_models,
|
"upscaling": upscaling_models,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue