diff --git a/api/extras.json b/api/extras.json index 2c1dc374..3fa65e5b 100644 --- a/api/extras.json +++ b/api/extras.json @@ -1,9 +1,23 @@ { "diffusion": [ - ["diffusion-knollingcase", "Aybeeceedee/knollingcase"], - ["diffusion-openjourney", "prompthero/openjourney"], - ["diffusion-stably-diffused-onnx-v2-6", "../models/tensors/stablydiffuseds_26.safetensors"], - ["diffusion-unstable-ink-dream-onnx-v6", "../models/tensors/unstableinkdream_v6.safetensors"] + { + "name": "diffusion-knollingcase", + "source": "Aybeeceedee/knollingcase" + }, + { + "name": "diffusion-openjourney", + "source": "prompthero/openjourney" + }, + { + "name": "diffusion-stablydiffused-aesthetic-v2-6", + "source": "civitai://6266?type=Pruned%20Model&format=SafeTensor", + "format": "safetensors" + }, + { + "name": "diffusion-unstable-ink-dream-v6", + "source": "civitai://5796", + "format": "safetensors" + } ], "correction": [], "upscaling": [] diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 13e6d7a6..86f13659 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -1,9 +1,3 @@ -from .correction_gfpgan import convert_correction_gfpgan -from .diffusion_original import convert_diffusion_original -from .diffusion_stable import convert_diffusion_stable -from .upscale_resrgan import convert_upscale_resrgan -from .utils import ConversionContext - import warnings from argparse import ArgumentParser from json import loads @@ -11,9 +5,17 @@ from logging import getLogger from os import environ, makedirs, path from sys import exit from typing import Dict, List, Optional, Tuple +from yaml import safe_load +from jsonschema import validate, ValidationError import torch +from .correction_gfpgan import convert_correction_gfpgan +from .diffusion_original import convert_diffusion_original +from .diffusion_stable import convert_diffusion_stable +from .upscale_resrgan import convert_upscale_resrgan +from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling + # suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75 warnings.filterwarnings( "ignore", ".*The shape inference of prim::Constant type is missing.*" @@ -29,20 +31,39 @@ Models = Dict[str, List[Tuple[str, str, Optional[int]]]] logger = getLogger(__name__) +model_sources: Dict[str, Tuple[str, str]] = { + "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"), +} + +model_source_huggingface = "huggingface://" + # recommended models base_models: Models = { "diffusion": [ # v1.x - ("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"), - ("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"), - # v2.x - ("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"), ( - "stable-diffusion-onnx-v2-inpainting", - "stabilityai/stable-diffusion-2-inpainting", + "stable-diffusion-onnx-v1-5", + model_source_huggingface + "runwayml/stable-diffusion-v1-5", ), + # ( + # "stable-diffusion-onnx-v1-inpainting", + # model_source_huggingface + "runwayml/stable-diffusion-inpainting", + # ), + # v2.x + # ( + # "stable-diffusion-onnx-v2-1", + # model_source_huggingface + "stabilityai/stable-diffusion-2-1", + # ), + # ( + # "stable-diffusion-onnx-v2-inpainting", + # model_source_huggingface + "stabilityai/stable-diffusion-2-inpainting", + # ), # TODO: should have its own converter - ("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"), + ( + 
"upscaling-stable-diffusion-x4", + model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler", + True, + ), ], "correction": [ ( @@ -79,35 +100,86 @@ model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")) training_device = "cuda" if torch.cuda.is_available() else "cpu" -def load_models(args, ctx: ConversionContext, models: Models): +def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str: + cache_name = path.join(ctx.cache_path, name) + if format is not None: + # add an extension if possible, some of the conversion code checks for it + cache_name = "%s.%s" % (cache_name, format) + + for proto in model_sources: + api_name, api_root = model_sources.get(proto) + if source.startswith(proto): + api_source = api_root % (source.removeprefix(proto)) + logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name) + return download_progress([(api_source, cache_name)]) + + if source.startswith(model_source_huggingface): + hub_source = source.removeprefix(model_source_huggingface) + logger.info("Downloading model from Huggingface Hub: %s", hub_source) + # from_pretrained has a bunch of useful logic that snapshot_download by itself down not + return hub_source + elif source.startswith("https://"): + logger.info("Downloading model from: %s", source) + return download_progress([(source, cache_name)]) + elif source.startswith("http://"): + logger.warning("Downloading model from insecure source: %s", source) + return download_progress([(source, cache_name)]) + elif source.startswith(path.sep) or source.startswith("."): + logger.info("Using local model: %s", source) + return source + else: + logger.info("Unknown model location, using path as provided: %s", source) + return source + + +def convert_models(ctx: ConversionContext, args, models: Models): if args.diffusion: - for source in models.get("diffusion"): - name, file = source + for model in models.get("diffusion"): + model = tuple_to_diffusion(model) + name = model.get("name") + if name in args.skip: - logger.info("Skipping model: %s", source[0]) + logger.info("Skipping model: %s", name) else: - if file.endswith(".safetensors") or file.endswith(".ckpt"): - convert_diffusion_original(ctx, *source, args.opset, args.half) + format = source_format(model) + source = fetch_model(ctx, name, model["source"], format=format) + + if format in ["safetensors", "ckpt"]: + convert_diffusion_original( + ctx, + model, + source, + ) else: - # TODO: make this a parameter in the JSON/dict - single_vae = "upscaling" in source[0] convert_diffusion_stable( - ctx, *source, args.opset, args.half, args.token, single_vae=single_vae + ctx, + model, + source, ) if args.upscaling: - for source in models.get("upscaling"): - if source[0] in args.skip: - logger.info("Skipping model: %s", source[0]) + for model in models.get("upscaling"): + model = tuple_to_upscaling(model) + name = model.get("name") + + if name in args.skip: + logger.info("Skipping model: %s", name) else: - convert_upscale_resrgan(ctx, *source, args.opset) + format = source_format(model) + source = fetch_model(ctx, name, model["source"], format=format) + convert_upscale_resrgan(ctx, model, source) if args.correction: - for source in models.get("correction"): - if source[0] in args.skip: - logger.info("Skipping model: %s", source[0]) + for model in models.get("correction"): + model = tuple_to_correction(model) + name = model.get("name") + + if name in args.skip: + logger.info("Skipping model: %s", name) else: - 
convert_correction_gfpgan(ctx, *source, args.opset) + format = source_format(model) + source = fetch_model(ctx, name, model["source"], format=format) + convert_correction_gfpgan(ctx, model, source) def main() -> int: @@ -146,7 +218,7 @@ def main() -> int: args = parser.parse_args() logger.info("CLI arguments: %s", args) - ctx = ConversionContext(model_path, training_device) + ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token) logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device) if not path.exists(model_path): @@ -154,16 +226,26 @@ def main() -> int: makedirs(model_path) logger.info("Converting base models.") - load_models(args, ctx, base_models) + convert_models(ctx, args, base_models) for file in args.extras: if file is not None and file != "": logger.info("Loading extra models from %s", file) try: with open(file, "r") as f: - data = loads(f.read()) + data = safe_load(f.read()) + + with open("./schemas/extras.yaml", "r") as f: + schema = safe_load(f.read()) + + logger.debug("validating chain request: %s against %s", data, schema) + + try: + validate(data, schema) logger.info("Converting extra models.") - load_models(args, ctx, data) + convert_models(ctx, args, data) + except ValidationError as err: + logger.error("Invalid data in extras file: %s", err) except Exception as err: logger.error("Error converting extra models: %s", err) diff --git a/api/onnx_web/convert/correction_gfpgan.py b/api/onnx_web/convert/correction_gfpgan.py index 13a32c8b..d0ec6d18 100644 --- a/api/onnx_web/convert/correction_gfpgan.py +++ b/api/onnx_web/convert/correction_gfpgan.py @@ -1,31 +1,34 @@ -import torch +from logging import getLogger +from os import path from shutil import copyfile + +import torch +from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.utils.download_util import load_file_from_url from torch.onnx import export -from os import path -from logging import getLogger -from basicsr.archs.rrdbnet_arch import RRDBNet -from .utils import ConversionContext + +from .utils import ConversionContext, ModelDict logger = getLogger(__name__) -@torch.no_grad() -def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int): - dest_path = path.join(ctx.model_path, name + ".pth") - dest_onnx = path.join(ctx.model_path, name + ".onnx") - logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx) - if path.isfile(dest_onnx): +@torch.no_grad() +def convert_correction_gfpgan( + ctx: ConversionContext, + model: ModelDict, + source: str, +): + name = model.get("name") + source = source or model.get("source") + scale = model.get("scale") + + dest = path.join(ctx.model_path, name + ".onnx") + logger.info("converting GFPGAN model: %s -> %s", name, dest) + + if path.isfile(dest): logger.info("ONNX model already exists, skipping.") return - if not path.isfile(dest_path): - logger.info("PTH model not found, downloading...") - download_path = load_file_from_url( - url=url, model_dir=dest_path + "-cache", progress=True, file_name=None - ) - copyfile(download_path, dest_path) - logger.info("loading and training model") model = RRDBNet( num_in_ch=3, @@ -36,7 +39,7 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale scale=scale, ) - torch_model = torch.load(dest_path, map_location=ctx.map_location) + torch_model = torch.load(source, map_location=ctx.map_location) # TODO: make sure strict=False is safe here if "params_ema" in torch_model: 
model.load_state_dict(torch_model["params_ema"], strict=False) @@ -54,15 +57,15 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale "output": {2: "width", 3: "height"}, } - logger.info("exporting ONNX model to %s", dest_onnx) + logger.info("exporting ONNX model to %s", dest) export( model, rng, - dest_onnx, + dest, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=opset, + opset_version=ctx.opset, export_params=True, ) logger.info("GFPGAN exported to ONNX successfully.") diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py index c39715f5..470ce847 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion_original.py @@ -11,6 +11,17 @@ # TODO: ask about license before merging ### +import json +import os +import re +import shutil +import traceback +from logging import getLogger +from typing import Dict, List + +import huggingface_hub.utils.tqdm +import safetensors.torch +import torch from diffusers import ( AutoencoderKL, DDIMScheduler, @@ -20,95 +31,34 @@ from diffusers import ( HeunDiscreteScheduler, LDMTextToImagePipeline, LMSDiscreteScheduler, + PaintByExamplePipeline, PNDMScheduler, StableDiffusionPipeline, - UNet2DConditionModel, PaintByExamplePipeline, + UNet2DConditionModel, +) +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import ( + LDMBertConfig, + LDMBertModel, ) -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from huggingface_hub import HfApi, hf_hub_download -from logging import getLogger from omegaconf import OmegaConf from pydantic import BaseModel -from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig -from typing import Dict, List, Union - -import huggingface_hub.utils.tqdm -import json -import os -import re -import safetensors.torch -import shutil -import torch -import traceback +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionConfig, +) from .diffusion_stable import convert_diffusion_stable -from .utils import ConversionContext +from .utils import ConversionContext, ModelDict logger = getLogger(__name__) -def get_images(): - return [] - -class Concept(BaseModel): - class_data_dir: str = "" - class_guidance_scale: float = 7.5 - class_infer_steps: int = 60 - class_negative_prompt: str = "" - class_prompt: str = "" - class_token: str = "" - instance_data_dir: str = "" - instance_prompt: str = "" - instance_token: str = "" - is_valid: bool = False - n_save_sample: int = 1 - num_class_images: int = 0 - num_class_images_per: int = 0 - sample_seed: int = -1 - save_guidance_scale: float = 7.5 - save_infer_steps: int = 60 - save_sample_negative_prompt: str = "" - save_sample_prompt: str = "" - save_sample_template: str = "" - - def __init__(self, input_dict: Union[Dict, None] = None, **kwargs): - super().__init__(**kwargs) - if input_dict is not None: - self.load_params(input_dict) - if self.is_valid and self.num_class_images != 0: - if self.num_class_images_per == 0: - images = get_images(self.instance_data_dir) - if len(images) < self.num_class_images * 2: - self.num_class_images_per = 1 - else: - self.num_class_images_per = self.num_class_images // len(images) - 
self.num_class_images = 0 - - def to_dict(self): - return self.dict() - - def to_json(self): - return json.dumps(self.to_dict()) - - def load_params(self, params_dict): - for key, value in params_dict.items(): - if hasattr(self, key): - setattr(self, key, value) - if self.instance_data_dir: - self.is_valid = os.path.isdir(self.instance_data_dir) - else: - self.is_valid = False - - -# Keys to save, replacing our dumb __init__ method -save_keys = [] - -# Keys to return to the ui when Load Settings is clicked. -ui_keys = [] - - def sanitize_name(name): return "".join(x for x in name if (x.isalnum() or x in "._- ")) @@ -200,7 +150,7 @@ class DreamboothConfig(BaseModel): super().__init__(**kwargs) model_name = sanitize_name(model_name) - model_dir = os.path.join(ctx.model_path, model_name) + model_dir = os.path.join(ctx.cache_path, model_name) working_dir = os.path.join(model_dir, "working") if not os.path.exists(working_dir): @@ -214,7 +164,6 @@ class DreamboothConfig(BaseModel): self.scheduler = scheduler self.v2 = v2 - # Actually save as a file def save(self, backup=False): """ Save the config file @@ -236,132 +185,6 @@ class DreamboothConfig(BaseModel): if hasattr(self, key): setattr(self, key, value) - # Pass a dict and return a list of Concept objects - def concepts(self, required: int = -1): - concepts = [] - c_idx = 0 - # If using a file for concepts and not requesting from UI, load from file - if self.use_concepts and self.concepts_path and required == -1: - concepts_list = concepts_from_file(self.concepts_path) - - # Otherwise, use 'stored' list - else: - concepts_list = self.concepts_list - if required == -1: - required = len(concepts_list) - - for concept_dict in concepts_list: - concept = Concept(input_dict=concept_dict) - if concept.is_valid: - if concept.class_data_dir == "" or concept.class_data_dir is None: - concept.class_data_dir = os.path.join(self.model_dir, f"classifiers_{c_idx}") - concepts.append(concept) - c_idx += 1 - - missing = len(concepts) - required - if missing > 0: - concepts.extend([Concept(None)] * missing) - return concepts - - # Set default values - def check_defaults(self): - if self.model_name is not None and self.model_name != "": - if self.revision == "" or self.revision is None: - self.revision = 0 - if self.epoch == "" or self.epoch is None: - self.epoch = 0 - self.model_name = "".join(x for x in self.model_name if (x.isalnum() or x in "._- ")) - models_path = "." 
# TODO: use ctx path - model_dir = os.path.join(models_path, self.model_name) - working_dir = os.path.join(model_dir, "working") - if not os.path.exists(working_dir): - os.makedirs(working_dir) - self.model_dir = model_dir - self.pretrained_model_name_or_path = working_dir - - -def concepts_from_file(concepts_path: str): - concepts = [] - if os.path.exists(concepts_path) and os.path.isfile(str): - try: - with open(concepts_path,"r") as concepts_file: - concepts_str = concepts_file.read() - except Exception as e: - print(f"Exception opening concepts file: {e}") - else: - concepts_str = concepts_path - - try: - concepts_data = json.loads(concepts_str) - for concept_data in concepts_data: - concept = Concept(input_dict=concept_data) - if concept.is_valid: - concepts.append(concept.__dict__) - except Exception as e: - print(f"Exception parsing concepts: {e}") - return concepts - - -def save_config(*args): - raise Exception("where tho") - params = list(args) - concept_keys = ["c1_", "c2_", "c3_", "c4_"] - model_name = params[0] - if model_name is None or model_name == "": - print("Invalid model name.") - return - config = from_file(ctx, model_name) - if config is None: - config = DreamboothConfig(model_name) - params_dict = dict(zip(save_keys, params)) - concepts_list = [] - # If using a concepts file/string, keep concepts_list empty. - if params_dict["db_use_concepts"] and params_dict["db_concepts_path"]: - concepts_list = [] - params_dict["concepts_list"] = concepts_list - else: - for concept_key in concept_keys: - concept_dict = {} - for key, param in params_dict.items(): - if concept_key in key and param is not None: - concept_dict[key.replace(concept_key, "")] = param - concept_test = Concept(concept_dict) - if concept_test.is_valid: - concepts_list.append(concept_test.__dict__) - existing_concepts = params_dict["concepts_list"] if "concepts_list" in params_dict else [] - if len(concepts_list) and not len(existing_concepts): - params_dict["concepts_list"] = concepts_list - - config.load_params(params_dict) - config.save() - - -def from_file(ctx: ConversionContext, model_name): - """ - Load config data from UI - Args: - model_name: The config to load - - Returns: Dict | None - - """ - if model_name == "" or model_name is None: - return None - - model_name = sanitize_name(model_name) - config_file = os.path.join(ctx.model_path, model_name, "db_config.json") - try: - with open(config_file, 'r') as openfile: - config_dict = json.load(openfile) - - config = DreamboothConfig(model_name) - config.load_params(config_dict) - return config - except Exception as e: - print(f"Exception loading config: {e}") - traceback.print_exc() - return None - # coding=utf-8 # Copyright 2022 The HuggingFace Inc. team. @@ -379,8 +202,6 @@ def from_file(ctx: ConversionContext, model_name): # limitations under the License. """ Conversion script for the LDM checkpoints. 
""" -def get_db_models(): - return [] def shave_segments(path, n_shave_prefix_segments=1): """ @@ -1075,7 +896,7 @@ def convert_open_clip_checkpoint(checkpoint): if 'cond_stage_model.model.text_projection' in checkpoint: d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0]) else: - print("No projection shape found, setting to 1024") + logger.debug("No projection shape found, setting to 1024") d_model = 1024 text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") @@ -1130,7 +951,7 @@ def replace_symlinks(path, base): blob_path = None if blob_path is None: - print("NO BLOB") + logger.debug("NO BLOB") return os.replace(blob_path, path) elif os.path.isdir(path): @@ -1140,7 +961,6 @@ def replace_symlinks(path, base): def download_model(db_config: DreamboothConfig, token): tmp_dir = os.path.join(db_config.model_dir, "src") - working_dir = db_config.pretrained_model_name_or_path hub_url = db_config.src if "http" in hub_url or "huggingface.co" in hub_url: @@ -1155,7 +975,7 @@ def download_model(db_config: DreamboothConfig, token): ) if repo_info.sha is None: - print("Unable to fetch repo?") + logger.warning("Unable to fetch repo?") return None, None siblings = repo_info.siblings @@ -1208,10 +1028,10 @@ def download_model(db_config: DreamboothConfig, token): if files_to_fetch and config_file: files_to_fetch.append(config_file) - print(f"Fetching files: {files_to_fetch}") + logger.info(f"Fetching files: {files_to_fetch}") if not len(files_to_fetch): - print("Nothing to fetch!") + logger.debug("Nothing to fetch!") return None, None @@ -1296,9 +1116,6 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"): return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type) - print("Could not find valid config. Returning default v1 config.") - return get_config_path(model_versions["v1"], train_types["default"], config_base_name, prediction_type="epsilon") - def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="", new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True): @@ -1352,7 +1169,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f if db_config is not None: original_config_file = config if model_info is not None: - print("Got model info.") + logger.debug("Got model info.") if ".ckpt" in model_info or ".safetensors" in model_info: # Set this to false, because we have a checkpoint where we can *maybe* get a revision. from_hub = False @@ -1360,28 +1177,26 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f checkpoint_file = model_info else: msg = "Unable to fetch model from hub." 
- print(msg) + logger.warning(msg) return "", "", 0, 0, "", "", "", "", image_size, "", msg - reset_safe = False - try: checkpoint = None map_location = torch.device("cpu") # Try to determine if v1 or v2 model if we have a ckpt if not from_hub: - print("Loading model from checkpoint.") + logger.info("Loading model from checkpoint.") _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": os.environ["SAFETENSORS_FAST_GPU"] = "1" try: - print("Loading safetensors...") + logger.debug("Loading safetensors...") checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu") except Exception as e: checkpoint = torch.jit.load(checkpoint_file) else: - print("Loading ckpt...") + logger.debug("Loading ckpt...") checkpoint = torch.load(checkpoint_file, map_location=map_location) checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint @@ -1401,7 +1216,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: if not is_512: # v2.1 needs to upcast attention - print("Setting upcast_attention") + logger.debug("Setting upcast_attention") upcast_attention = True v2 = True else: @@ -1410,15 +1225,15 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet") try: unet = UNet2DConditionModel.from_pretrained(unet_dir) - print("Loaded unet.") + logger.debug("Loaded unet.") unet_dict = unet.state_dict() key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight" if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024: - print("We got v2!") + logger.debug("UNet using v2 parameters.") v2 = True except: - print("Exception loading unet!") + logger.error("Exception loading unet!") traceback.print_exc() if v2 and not is_512: @@ -1428,7 +1243,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f original_config_file = get_config_file(train_unfrozen, v2, prediction_type) - print(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}") + logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}") db_config.resolution = image_size db_config.lifetime_revision = revision db_config.epoch = epoch @@ -1438,7 +1253,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f db_config.save() return - print(f"{'v2' if v2 else 'v1'} model loaded.") + logger.info(f"{'v2' if v2 else 'v1'} model loaded.") # Use existing YAML if present if checkpoint_file is not None: @@ -1447,10 +1262,10 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f original_config_file = config_check if original_config_file is None or not os.path.exists(original_config_file): - print("Unable to select a config file: %s" % (original_config_file)) + logger.warning("Unable to select a config file: %s" % (original_config_file)) return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file." 
- print(f"Trying to load: {original_config_file}") + logger.debug(f"Trying to load: {original_config_file}") original_config = OmegaConf.load(original_config_file) num_train_timesteps = original_config.model.params.timesteps @@ -1489,7 +1304,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - print("Converting unet...") + logger.info("Converting UNet...") # Convert the UNet2DConditionModel model. unet_config = create_unet_diffusers_config(original_config, image_size=image_size) unet_config["upcast_attention"] = upcast_attention @@ -1501,14 +1316,16 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f db_config.has_ema = has_ema db_config.save() unet.load_state_dict(converted_unet_checkpoint) - print("Converting vae...") + + logger.info("Converting VAE...") # Convert the VAE model. vae_config = create_vae_diffusers_config(original_config, image_size=image_size) converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) - print("Converting text encoder...") + + logger.info("Converting text encoder...") # Convert the text model. text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] if text_model_type == "FrozenOpenCLIPEmbedder": @@ -1557,17 +1374,17 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) except Exception as e: - print(f"Exception setting up output: {e}") + logger.error(f"Exception setting up output: {e}") pipe = None traceback.print_exc() if pipe is None or db_config is None: msg = "Pipeline or config is not set, unable to continue." 
-        print(msg)
+        logger.error(msg)
         return "", "", 0, 0, "", "", "", "", image_size, "", msg
     else:
         resolution = db_config.resolution
-        print("Saving diffusion model...")
+        logger.info("Saving diffusion model...")
         pipe.save_pretrained(db_config.pretrained_model_name_or_path)
         result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
         model_dir = db_config.model_dir
@@ -1576,12 +1393,12 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         src = db_config.src
         required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
         if original_config_file is not None and os.path.exists(original_config_file):
-            logger.warn("copying original config: %s -> %s", original_config_file, db_config.model_dir)
+            logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
             shutil.copy(original_config_file, db_config.model_dir)
             basename = os.path.basename(original_config_file)
             new_ex_path = os.path.join(db_config.model_dir, basename)
             new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
-            logger.warn("copying model config to new name: %s -> %s", new_ex_path, new_name)
+            logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
             if os.path.exists(new_name):
                 os.remove(new_name)
             os.rename(new_ex_path, new_name)
@@ -1601,27 +1418,35 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             os.makedirs(rem_dir)
 
 
-    print(result_status)
+    logger.info(result_status)
 
     return
 
 
-def convert_diffusion_original(ctx: ConversionContext, model_name: str, tensor_file: str, opset: int, half: bool):
-    model_path = os.path.join(ctx.model_path, model_name)
-    torch_name = model_name.replace("onnx", "torch")
-    torch_path = os.path.join(ctx.model_path, torch_name)
-    working_name = os.path.join(ctx.model_path, torch_name, "working")
-    logger.info("Converting original Diffusers checkpoint %s: %s -> %s", model_name, tensor_file, model_path)
+def convert_diffusion_original(
+    ctx: ConversionContext,
+    model: ModelDict,
+    source: str,
+):
+    name = model["name"]
+    source = source or model["source"]
 
-    if os.path.exists(model_path):
+    dest = os.path.join(ctx.model_path, name)
+    logger.info("Converting original Diffusers checkpoint %s: %s -> %s", name, source, dest)
+
+    if os.path.exists(dest):
         logger.info("ONNX pipeline already exists, skipping.")
         return
 
+    torch_name = name + "-torch"
+    torch_path = os.path.join(ctx.cache_path, torch_name)
+    working_name = os.path.join(ctx.cache_path, torch_name, "working")
+
     if os.path.exists(torch_path):
         logger.info("Torch pipeline already exists, reusing.")
     else:
-        logger.info("Converting original Diffusers check to Torch model: %s -> %s", tensor_file, torch_path)
-        extract_checkpoint(ctx, torch_name, tensor_file, from_hub=False)
+        logger.info("Converting original Diffusers checkpoint to Torch model: %s -> %s", source, torch_path)
+        extract_checkpoint(ctx, torch_name, source, from_hub=False)
         logger.info("Converted original Diffusers checkpoint to Torch model.")
 
-    convert_diffusion_stable(ctx, model_path, working_name, opset, half, None)
-    logger.info("ONNX pipeline saved to %s", model_name)
\ No newline at end of file
+    convert_diffusion_stable(ctx, model, working_name)
+    logger.info("ONNX pipeline saved to %s", name)
diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion_stable.py
index c52882ac..fd8a0ecd 100644
--- a/api/onnx_web/convert/diffusion_stable.py
+++ b/api/onnx_web/convert/diffusion_stable.py
@@ -1,19 +1,22 @@ +from logging import getLogger +from os import mkdir, path +from pathlib import Path +from shutil import rmtree +from typing import Dict + +import torch from diffusers import ( OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline, ) -from torch.onnx import export -from logging import getLogger - -from shutil import rmtree -import torch -from os import path, mkdir -from pathlib import Path from onnx import load, save_model +from torch.onnx import export +from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( + OnnxStableDiffusionUpscalePipeline, +) from .utils import ConversionContext -from ..diffusion.pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline logger = getLogger(__name__) @@ -47,22 +50,22 @@ def onnx_export( @torch.no_grad() def convert_diffusion_stable( - ctx: ConversionContext, - name: str, - url: str, - opset: int, - half: bool, - token: str, - single_vae: bool = False, + ctx: ConversionContext, + model: Dict, + source: str, ): """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ - dtype = torch.float16 if half else torch.float32 + name = model.get("name") + source = source or model.get("source") + single_vae = model.get("single_vae") + + dtype = torch.float16 if ctx.half else torch.float32 dest_path = path.join(ctx.model_path, name) # diffusers go into a directory rather than .onnx file - logger.info("converting Stable Diffusion model %s: %s -> %s/", name, url, dest_path) + logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path) if single_vae: logger.info("converting model with single VAE") @@ -71,13 +74,16 @@ def convert_diffusion_stable( logger.info("ONNX model already exists, skipping.") return - if half and ctx.training_device != "cuda": + if ctx.half and ctx.training_device != "cuda": raise ValueError( "Half precision model export is only supported on GPUs with CUDA" ) pipeline = StableDiffusionPipeline.from_pretrained( - url, torch_dtype=dtype, use_auth_token=token + source, + torch_dtype=dtype, + use_auth_token=ctx.token, + # cache_dir=path.join(ctx.cache_path, name) ).to(ctx.training_device) output_path = Path(dest_path) @@ -94,14 +100,16 @@ def convert_diffusion_stable( onnx_export( pipeline.text_encoder, # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files - model_args=(text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)), + model_args=( + text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32) + ), output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], dynamic_axes={ "input_ids": {0: "batch", 1: "sequence"}, }, - opset=opset, + opset=ctx.opset, ) del pipeline.text_encoder @@ -113,7 +121,9 @@ def convert_diffusion_stable( unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int) else: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] - unet_scale = torch.tensor(False).to(device=ctx.training_device, dtype=torch.bool) + unet_scale = torch.tensor(False).to( + device=ctx.training_device, dtype=torch.bool + ) unet_in_channels = pipeline.unet.config.in_channels unet_sample_size = pipeline.unet.config.sample_size @@ -139,7 +149,7 @@ def convert_diffusion_stable( "timestep": {0: "batch"}, "encoder_hidden_states": {0: "batch", 1: "sequence"}, }, - opset=opset, + 
opset=ctx.opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) unet_model_path = str(unet_path.absolute().as_posix()) @@ -182,7 +192,7 @@ def convert_diffusion_stable( dynamic_axes={ "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, }, - opset=opset, + opset=ctx.opset, ) else: # VAE ENCODER @@ -207,7 +217,7 @@ def convert_diffusion_stable( dynamic_axes={ "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, }, - opset=opset, + opset=ctx.opset, ) # VAE DECODER @@ -230,7 +240,7 @@ def convert_diffusion_stable( dynamic_axes={ "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, }, - opset=opset, + opset=ctx.opset, ) del pipeline.vae @@ -261,7 +271,7 @@ def convert_diffusion_stable( "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"}, "images": {0: "batch", 1: "height", 2: "width", 3: "channels"}, }, - opset=opset, + opset=ctx.opset, ) del pipeline.safety_checker safety_checker = OnnxRuntimeModel.from_pretrained( @@ -312,4 +322,3 @@ def convert_diffusion_stable( ) logger.info("ONNX pipeline is loadable") - diff --git a/api/onnx_web/convert/upscale_resrgan.py b/api/onnx_web/convert/upscale_resrgan.py index 19e6581c..51a95804 100644 --- a/api/onnx_web/convert/upscale_resrgan.py +++ b/api/onnx_web/convert/upscale_resrgan.py @@ -1,32 +1,34 @@ -import torch +from logging import getLogger +from os import path from shutil import copyfile + +import torch +from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.utils.download_util import load_file_from_url from torch.onnx import export -from os import path -from logging import getLogger -from basicsr.archs.rrdbnet_arch import RRDBNet -from .utils import ConversionContext + +from .utils import ConversionContext, ModelDict logger = getLogger(__name__) @torch.no_grad() -def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int): - dest_path = path.join(ctx.model_path, name + ".pth") - dest_onnx = path.join(ctx.model_path, name + ".onnx") - logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx) +def convert_upscale_resrgan( + ctx: ConversionContext, + model: ModelDict, + source: str, +): + name = model.get("name") + source = source or model.get("source") + scale = model.get("scale") - if path.isfile(dest_onnx): + dest = path.join(ctx.model_path, name + ".onnx") + logger.info("converting Real ESRGAN model: %s -> %s", name, dest) + + if path.isfile(dest): logger.info("ONNX model already exists, skipping.") return - if not path.isfile(dest_path): - logger.info("PTH model not found, downloading...") - download_path = load_file_from_url( - url=url, model_dir=dest_path + "-cache", progress=True, file_name=None - ) - copyfile(download_path, dest_path) - logger.info("loading and training model") model = RRDBNet( num_in_ch=3, @@ -37,7 +39,7 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: scale=scale, ) - torch_model = torch.load(dest_path, map_location=ctx.map_location) + torch_model = torch.load(source, map_location=ctx.map_location) if "params_ema" in torch_model: model.load_state_dict(torch_model["params_ema"]) else: @@ -54,15 +56,15 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: "output": {2: "width", 3: "height"}, } - logger.info("exporting ONNX model to %s", dest_onnx) + logger.info("exporting ONNX model to %s", dest) export( model, rng, - dest_onnx, + dest, input_names=input_names, output_names=output_names, 
         dynamic_axes=dynamic_axes,
-        opset_version=opset,
+        opset_version=ctx.opset,
         export_params=True,
     )
     logger.info("Real ESRGAN exported to ONNX successfully.")
diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py
index 5149c0bd..5a984900 100644
--- a/api/onnx_web/convert/utils.py
+++ b/api/onnx_web/convert/utils.py
@@ -1,7 +1,129 @@
+import shutil
+from functools import partial
+from logging import getLogger
+from os import path
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import requests
 import torch
+from tqdm.auto import tqdm
+
+logger = getLogger(__name__)
+
+
+ModelDict = Dict[str, Union[str, int]]
+LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
+
 
 class ConversionContext:
-    def __init__(self, model_path: str, device: str) -> None:
-        self.model_path = model_path
-        self.training_device = device
-        self.map_location = torch.device(device)
\ No newline at end of file
+    def __init__(
+        self,
+        model_path: str,
+        device: str,
+        cache_path: Optional[str] = None,
+        half: Optional[bool] = False,
+        opset: Optional[int] = None,
+        token: Optional[str] = None,
+    ) -> None:
+        self.model_path = model_path
+        self.cache_path = cache_path or path.join(model_path, ".cache")
+        self.training_device = device
+        self.map_location = torch.device(device)
+        self.half = half
+        self.opset = opset
+        self.token = token
+
+
+def download_progress(urls: List[Tuple[str, str]]):
+    for url, dest in urls:
+        dest_path = Path(dest).expanduser().resolve()
+        dest_path.parent.mkdir(parents=True, exist_ok=True)
+
+        if dest_path.exists():
+            logger.info("Destination already exists: %s", dest_path)
+            return str(dest_path.absolute())
+
+        req = requests.get(url, stream=True, allow_redirects=True)
+        if req.status_code != 200:
+            req.raise_for_status()  # raises HTTPError for 4xx and 5xx responses
+            raise RuntimeError(
+                "Request to %s failed with status code: %s" % (url, req.status_code)
+            )
+
+        total = int(req.headers.get("Content-Length", 0))
+        desc = "unknown" if total == 0 else ""
+        req.raw.read = partial(req.raw.read, decode_content=True)
+        with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
+            with dest_path.open("wb") as f:
+                shutil.copyfileobj(data, f)
+
+        return str(dest_path.absolute())
+
+
+def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
+    if isinstance(model, list) or isinstance(model, tuple):
+        name, source, *rest = model  # legacy tuple: (name, source, [scale], [half], [opset])
+        scale = rest[0] if len(rest) > 0 else 1
+        half = rest[1] if len(rest) > 1 else False
+        opset = rest[2] if len(rest) > 2 else None
+
+        return {
+            "name": name,
+            "source": source,
+            "half": half,
+            "opset": opset,
+            "scale": scale,
+        }
+    else:
+        return model
+
+
+def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
+    if isinstance(model, list) or isinstance(model, tuple):
+        name, source, *rest = model  # legacy tuple: (name, source, [single_vae], [half], [opset])
+        single_vae = rest[0] if len(rest) > 0 else False
+        half = rest[1] if len(rest) > 1 else False
+        opset = rest[2] if len(rest) > 2 else None
+
+        return {
+            "name": name,
+            "source": source,
+            "half": half,
+            "opset": opset,
+            "single_vae": single_vae,
+        }
+    else:
+        return model
+
+
+def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
+    if isinstance(model, list) or isinstance(model, tuple):
+        name, source, *rest = model  # legacy tuple: (name, source, [scale], [half], [opset])
+        scale = rest[0] if len(rest) > 0 else 1
+        half = rest[1] if len(rest) > 1 else False
+        opset = rest[2] if len(rest) > 2 else None
+
+        return {
+            "name": name,
+            "source": source,
+            "half": half,
+            "opset": opset,
+            "scale": scale,
+        }
+    else:
+        return model
+
+
+known_formats = ["onnx", "pth", "ckpt", "safetensors"]
+
+def source_format(model: Dict) -> Optional[str]:
+    if "format" in model:
+        return model["format"]
+
+    if "source" in model:
+        ext = path.splitext(model["source"])[1].lstrip(".")  # splitext keeps the leading dot
+        if ext in known_formats:
+            return ext
+
+    return None
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index d4490e7a..58eeb1c1 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -649,10 +649,10 @@ def chain():
         return error_reply("chain pipeline must have a body")
 
     data = yaml.safe_load(body)
-    with open("./schema.yaml", "r") as f:
+    with open("./schemas/chain.yaml", "r") as f:
         schema = yaml.safe_load(f.read())
 
-    logger.info("validating chain request: %s against %s", data, schema)
+    logger.debug("validating chain request: %s against %s", data, schema)
     validate(data, schema)
 
     # get defaults from the regular parameters
diff --git a/api/pyproject.toml b/api/pyproject.toml
index efe56db3..04c7bca4 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -1,5 +1,5 @@
 [tool.black]
-force-exclude = '''/(lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py'''
+force-exclude = '''/(diffusion_original|lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py'''
 
 [tool.isort]
 profile = "black"
diff --git a/api/schema.yaml b/api/schemas/chain.yaml
similarity index 94%
rename from api/schema.yaml
rename to api/schemas/chain.yaml
index 98052c82..1b41c503 100644
--- a/api/schema.yaml
+++ b/api/schemas/chain.yaml
@@ -1,4 +1,4 @@
-$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml
+$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/chain.yaml
 $schema: https://json-schema.org/draft/2020-12/schema
 
 $defs:
diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml
new file mode 100644
index 00000000..56230cc4
--- /dev/null
+++ b/api/schemas/extras.yaml
@@ -0,0 +1,65 @@
+$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/extras.yaml
+$schema: https://json-schema.org/draft/2020-12/schema
+
+$defs:
+  legacy_tuple:
+    type: array
+    items:
+      oneOf:
+        - type: string
+        - type: number
+
+  base_model:
+    type: object
+    required: [name, source]
+    properties:
+      format:
+        type: string
+        enum: [onnx, pth, ckpt, safetensors]
+      half:
+        type: boolean
+      name:
+        type: string
+      opset:
+        type: number
+      source:
+        type: string
+
+  correction_model:
+    allOf:
+      - $ref: "#/$defs/base_model"
+
+  diffusion_model:
+    allOf:
+      - $ref: "#/$defs/base_model"
+
+  upscaling_model:
+    allOf:
+      - $ref: "#/$defs/base_model"
+      - type: object
+        required: [scale]
+        properties:
+          scale:
+            type: number
+
+type: object
+additionalProperties: False
+properties:
+  diffusion:
+    type: array
+    items:
+      oneOf:
+        - $ref: "#/$defs/legacy_tuple"
+        - $ref: "#/$defs/diffusion_model"
+  correction:
+    type: array
+    items:
+      oneOf:
+        - $ref: "#/$defs/legacy_tuple"
+        - $ref: "#/$defs/correction_model"
+  upscaling:
+    type: array
+    items:
+      oneOf:
+        - $ref: "#/$defs/legacy_tuple"
+        - $ref: "#/$defs/upscaling_model"
diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace
index 78337b97..ebd137ae 100644
--- a/onnx-web.code-workspace
+++ b/onnx-web.code-workspace
@@ -58,6 +58,7 @@
     "rocm",
     "RRDB",
     "runwayml",
+    "safetensors",
     "scandir",
     "scipy",
     "scrollback",
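
Reviewer note, not part of the diff: the sketch below shows how an extras file in the new dict format can be validated against `schemas/extras.yaml`, mirroring the check that `convert/__main__.py` now performs before calling `convert_models`. The sample entries come from the updated `api/extras.json`; the standalone-script framing and the relative schema path (run from the `api/` directory) are illustrative assumptions.

```python
# Minimal sketch: validate an extras definition against schemas/extras.yaml.
# Assumes the working directory is api/, like the conversion entry point.
from jsonschema import ValidationError, validate
from yaml import safe_load

extras = {
    "diffusion": [
        # new dict format, with an explicit source protocol and format hint
        {
            "name": "diffusion-unstable-ink-dream-v6",
            "source": "civitai://5796",
            "format": "safetensors",
        },
        # legacy tuple format, still accepted and upgraded by tuple_to_diffusion
        ["diffusion-openjourney", "prompthero/openjourney"],
    ],
    "correction": [],
    "upscaling": [],
}

with open("./schemas/extras.yaml", "r") as f:
    schema = safe_load(f.read())

try:
    validate(extras, schema)
    print("extras definition is valid")
except ValidationError as err:
    print("invalid extras definition: %s" % err)
```

Because the schema accepts `oneOf` the legacy tuple and the new dict per entry, existing extras files keep working while new entries can carry the `source` protocol (`civitai://`, `huggingface://`, plain HTTP(S), or a local path) and an explicit `format` for sources whose URL does not end in a recognizable extension.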