diff --git a/api/extras.json b/api/extras.json index 3fa65e5b..4371eeb6 100644 --- a/api/extras.json +++ b/api/extras.json @@ -1,5 +1,10 @@ { "diffusion": [ + { + "name": "diffusion-ugly-sonic", + "source": "runwayml/stable-diffusion-v1-5", + "inversion": "sd-concepts-library/ugly-sonic" + }, { "name": "diffusion-knollingcase", "source": "Aybeeceedee/knollingcase" diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 774a9e72..eab146fb 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -10,8 +10,9 @@ from jsonschema import ValidationError, validate from yaml import safe_load from .correction_gfpgan import convert_correction_gfpgan -from .diffusion_original import convert_diffusion_original -from .diffusion_stable import convert_diffusion_stable +from .diffusion.original import convert_diffusion_original +from .diffusion.diffusers import convert_diffusion_diffusers +from .diffusion.textual_inversion import convert_diffusion_textual_inversion from .upscale_resrgan import convert_upscale_resrgan from .utils import ( ConversionContext, @@ -216,6 +217,9 @@ def convert_models(ctx: ConversionContext, args, models: Models): ctx, name, model["source"], model_format=model_format ) + if "inversion" in model: + convert_diffusion_textual_inversion(ctx, source, model["inversion"]) + if model_format in model_formats_original: convert_diffusion_original( ctx, @@ -223,7 +227,7 @@ def convert_models(ctx: ConversionContext, args, models: Models): source, ) else: - convert_diffusion_stable( + convert_diffusion_diffusers( ctx, model, source, diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion/diffusers.py similarity index 98% rename from api/onnx_web/convert/diffusion_stable.py rename to api/onnx_web/convert/diffusion/diffusers.py index 998e2fb5..0a805dcb 100644 --- a/api/onnx_web/convert/diffusion_stable.py +++ b/api/onnx_web/convert/diffusion/diffusers.py @@ -25,12 +25,11 @@ from 
diffusers import ( from onnx import load, save_model from torch.onnx import export -from onnx_web.diffusion.load import optimize_pipeline - -from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( +from ...diffusion.load import optimize_pipeline +from ...diffusion.pipeline_onnx_stable_diffusion_upscale import ( OnnxStableDiffusionUpscalePipeline, ) -from .utils import ConversionContext +from ..utils import ConversionContext logger = getLogger(__name__) @@ -63,7 +62,7 @@ def onnx_export( @torch.no_grad() -def convert_diffusion_stable( +def convert_diffusion_diffusers( ctx: ConversionContext, model: Dict, source: str, diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion/original.py similarity index 99% rename from api/onnx_web/convert/diffusion_original.py rename to api/onnx_web/convert/diffusion/original.py index da00c45b..b75be023 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -53,8 +53,8 @@ from transformers import ( CLIPVisionConfig, ) -from .diffusion_stable import convert_diffusion_stable -from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name +from .diffusers import convert_diffusion_diffusers +from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name logger = getLogger(__name__) @@ -1428,5 +1428,5 @@ def convert_diffusion_original( if "vae" in model: del model["vae"] - convert_diffusion_stable(ctx, model, working_name) + convert_diffusion_diffusers(ctx, model, working_name) logger.info("ONNX pipeline saved to %s", name) diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py new file mode 100644 index 00000000..24f5a1bc --- /dev/null +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -0,0 +1,88 @@ +from os import mkdir, path +from huggingface_hub.file_download import hf_hub_download +from transformers import CLIPTokenizer, 
from logging import getLogger
from os import makedirs, path
from sys import argv

import torch
from huggingface_hub.file_download import hf_hub_download
from torch.onnx import export
from transformers import CLIPTextModel, CLIPTokenizer

from ..utils import ConversionContext, sanitize_name

logger = getLogger(__name__)


def convert_diffusion_textual_inversion(
    context: ConversionContext, base_model: str, inversion: str
):
    """Merge a Textual Inversion concept into a base model's text encoder
    and export the patched encoder to ONNX under the conversion cache.

    Downloads ``learned_embeds.bin`` and ``token_identifier.txt`` from the
    concept repo, adds the concept token to the base tokenizer, injects the
    learned embedding into the text encoder's embedding table, and exports
    the result to ``<cache>/inversion-<name>/text_encoder/model.onnx``.

    :param context: conversion settings (cache path, opset, devices)
    :param base_model: HF repo id or local path of the base diffusion model
    :param inversion: HF repo id of the textual-inversion concept
    :raises ValueError: if the concept token already exists in the tokenizer
    """
    dest_path = path.join(context.cache_path, f"inversion-{sanitize_name(inversion)}")
    logger.info("converting textual inversion: %s -> %s", inversion, dest_path)

    # the export below writes into a text_encoder subfolder, and
    # torch.onnx.export will not create missing directories itself
    encoder_path = path.join(dest_path, "text_encoder")
    makedirs(encoder_path, exist_ok=True)

    embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
    token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")

    with open(token_file, "r") as f:
        # strip any trailing newline so the token matches exactly
        token = f.read().strip()

    tokenizer = CLIPTokenizer.from_pretrained(
        base_model,
        subfolder="tokenizer",
    )
    text_encoder = CLIPTextModel.from_pretrained(
        base_model,
        subfolder="text_encoder",
    )

    loaded_embeds = torch.load(embeds_file, map_location=context.map_location)

    # separate token and the embeds
    trained_token = list(loaded_embeds.keys())[0]
    embeds = loaded_embeds[trained_token]

    # cast to dtype of text_encoder; Tensor.to is NOT in-place, so the
    # result must be reassigned or the cast is silently lost
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # add the token in tokenizer
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
        )

    # resize the token embeddings to make room for the new token
    text_encoder.resize_token_embeddings(len(tokenizer))

    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds

    # trace the patched encoder with a fixed-length sample prompt
    text_input = tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    export(
        text_encoder,
        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
        (text_input.input_ids.to(device=context.training_device, dtype=torch.int32)),
        f=path.join(encoder_path, "model.onnx"),
        input_names=["input_ids"],
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
        },
        do_constant_folding=True,
        opset_version=context.opset,
    )


if __name__ == "__main__":
    context = ConversionContext.from_environ()
    convert_diffusion_textual_inversion(context, argv[1], argv[2])
glob(path.join(context.model_path, "inversion-*")) + ] + inversion_models = list(set(inversion_models)) + inversion_models.sort() + upscaling_models = [ get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) ] @@ -496,8 +504,9 @@ def list_mask_filters(): def list_models(): return jsonify( { - "diffusion": diffusion_models, "correction": correction_models, + "diffusion": diffusion_models, + "inversion": inversion_models, "upscaling": upscaling_models, } )