feat(api): convert Textual Inversion weights
This commit is contained in:
parent
947a1bfdec
commit
a31f7b9e1f
|
@ -1,5 +1,10 @@
|
|||
{
|
||||
"diffusion": [
|
||||
{
|
||||
"name": "diffusion-ugly-sonic",
|
||||
"source": "runwayml/stable-diffusion-v1-5",
|
||||
"inversion": "sd-concepts-library/ugly-sonic"
|
||||
},
|
||||
{
|
||||
"name": "diffusion-knollingcase",
|
||||
"source": "Aybeeceedee/knollingcase"
|
||||
|
|
|
@ -10,8 +10,9 @@ from jsonschema import ValidationError, validate
|
|||
from yaml import safe_load
|
||||
|
||||
from .correction_gfpgan import convert_correction_gfpgan
|
||||
from .diffusion_original import convert_diffusion_original
|
||||
from .diffusion_stable import convert_diffusion_stable
|
||||
from .diffusion.original import convert_diffusion_original
|
||||
from .diffusion.diffusers import convert_diffusion_diffusers
|
||||
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
|
||||
from .upscale_resrgan import convert_upscale_resrgan
|
||||
from .utils import (
|
||||
ConversionContext,
|
||||
|
@ -216,6 +217,9 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
ctx, name, model["source"], model_format=model_format
|
||||
)
|
||||
|
||||
if "inversion" in model:
|
||||
convert_diffusion_textual_inversion(ctx, source, model["inversion"])
|
||||
|
||||
if model_format in model_formats_original:
|
||||
convert_diffusion_original(
|
||||
ctx,
|
||||
|
@ -223,7 +227,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
source,
|
||||
)
|
||||
else:
|
||||
convert_diffusion_stable(
|
||||
convert_diffusion_diffusers(
|
||||
ctx,
|
||||
model,
|
||||
source,
|
||||
|
|
|
@ -25,12 +25,11 @@ from diffusers import (
|
|||
from onnx import load, save_model
|
||||
from torch.onnx import export
|
||||
|
||||
from onnx_web.diffusion.load import optimize_pipeline
|
||||
|
||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||
from ...diffusion.load import optimize_pipeline
|
||||
from ...diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from .utils import ConversionContext
|
||||
from ..utils import ConversionContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -63,7 +62,7 @@ def onnx_export(
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffusion_stable(
|
||||
def convert_diffusion_diffusers(
|
||||
ctx: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
|
@ -53,8 +53,8 @@ from transformers import (
|
|||
CLIPVisionConfig,
|
||||
)
|
||||
|
||||
from .diffusion_stable import convert_diffusion_stable
|
||||
from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
|
||||
from .diffusers import convert_diffusion_diffusers
|
||||
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -1428,5 +1428,5 @@ def convert_diffusion_original(
|
|||
if "vae" in model:
|
||||
del model["vae"]
|
||||
|
||||
convert_diffusion_stable(ctx, model, working_name)
|
||||
convert_diffusion_diffusers(ctx, model, working_name)
|
||||
logger.info("ONNX pipeline saved to %s", name)
|
|
@ -0,0 +1,88 @@
|
|||
from os import mkdir, path
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
from torch.onnx import export
|
||||
from sys import argv
|
||||
from logging import getLogger
|
||||
|
||||
from ..utils import ConversionContext, sanitize_name
|
||||
|
||||
import torch
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def convert_diffusion_textual_inversion(context: ConversionContext, base_model: str, inversion: str):
|
||||
cache_path = path.join(context.cache_path, f"inversion-{sanitize_name(inversion)}")
|
||||
logger.info("converting textual inversion: %s -> %s", inversion, cache_path)
|
||||
|
||||
if not path.exists(cache_path):
|
||||
mkdir(cache_path)
|
||||
|
||||
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
||||
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
||||
|
||||
with open(token_file, "r") as f:
|
||||
token = f.read()
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
base_model,
|
||||
subfolder="tokenizer",
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
base_model,
|
||||
subfolder="text_encoder",
|
||||
)
|
||||
|
||||
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
||||
|
||||
# separate token and the embeds
|
||||
trained_token = list(loaded_embeds.keys())[0]
|
||||
embeds = loaded_embeds[trained_token]
|
||||
|
||||
# cast to dtype of text_encoder
|
||||
dtype = text_encoder.get_input_embeddings().weight.dtype
|
||||
embeds.to(dtype)
|
||||
|
||||
# add the token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# resize the token embeddings
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# get the id for the token and assign the embeds
|
||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
||||
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
||||
|
||||
# conversion stuff
|
||||
text_input = tokenizer(
|
||||
"A sample prompt",
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
export(
|
||||
text_encoder,
|
||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||
(
|
||||
text_input.input_ids.to(device=context.training_device, dtype=torch.int32)
|
||||
),
|
||||
f=path.join(cache_path, "text_encoder", "model.onnx"),
|
||||
input_names=["input_ids"],
|
||||
output_names=["last_hidden_state", "pooler_output"],
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch", 1: "sequence"},
|
||||
},
|
||||
do_constant_folding=True,
|
||||
opset_version=context.opset,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
context = ConversionContext.from_environ()
|
||||
convert_diffusion_textual_inversion(context, argv[1], argv[2])
|
|
@ -122,8 +122,9 @@ chain_stages = {
|
|||
available_platforms: List[DeviceParams] = []
|
||||
|
||||
# loaded from model_path
|
||||
diffusion_models: List[str] = []
|
||||
correction_models: List[str] = []
|
||||
diffusion_models: List[str] = []
|
||||
inversion_models: List[str] = []
|
||||
upscaling_models: List[str] = []
|
||||
|
||||
|
||||
|
@ -301,8 +302,9 @@ def get_model_name(model: str) -> str:
|
|||
|
||||
|
||||
def load_models(context: ServerContext) -> None:
|
||||
global diffusion_models
|
||||
global correction_models
|
||||
global diffusion_models
|
||||
global inversion_models
|
||||
global upscaling_models
|
||||
|
||||
diffusion_models = [
|
||||
|
@ -323,6 +325,12 @@ def load_models(context: ServerContext) -> None:
|
|||
correction_models = list(set(correction_models))
|
||||
correction_models.sort()
|
||||
|
||||
inversion_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
|
||||
]
|
||||
inversion_models = list(set(inversion_models))
|
||||
inversion_models.sort()
|
||||
|
||||
upscaling_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
|
||||
]
|
||||
|
@ -496,8 +504,9 @@ def list_mask_filters():
|
|||
def list_models():
|
||||
return jsonify(
|
||||
{
|
||||
"diffusion": diffusion_models,
|
||||
"correction": correction_models,
|
||||
"diffusion": diffusion_models,
|
||||
"inversion": inversion_models,
|
||||
"upscaling": upscaling_models,
|
||||
}
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue