diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index 38188e97..3e058362 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -1,20 +1,26 @@
 import warnings
 from argparse import ArgumentParser
-from json import loads
 from logging import getLogger
 from os import environ, makedirs, path
 from sys import exit
 from typing import Dict, List, Optional, Tuple
-from yaml import safe_load
-from jsonschema import validate, ValidationError
 
 import torch
+from jsonschema import ValidationError, validate
+from yaml import safe_load
 
 from .correction_gfpgan import convert_correction_gfpgan
 from .diffusion_original import convert_diffusion_original
 from .diffusion_stable import convert_diffusion_stable
 from .upscale_resrgan import convert_upscale_resrgan
-from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
+from .utils import (
+    ConversionContext,
+    download_progress,
+    source_format,
+    tuple_to_correction,
+    tuple_to_diffusion,
+    tuple_to_upscaling,
+)
 
 # suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
 warnings.filterwarnings(
@@ -100,7 +106,9 @@
 model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
 training_device = "cuda" if torch.cuda.is_available() else "cpu"
 
-def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str:
+def fetch_model(
+    ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
+) -> str:
     cache_name = path.join(ctx.cache_path, name)
     if format is not None:
         # add an extension if possible, some of the conversion code checks for it
@@ -110,7 +118,9 @@
         api_name, api_root = model_sources.get(proto)
         if source.startswith(proto):
             api_source = api_root % (source.removeprefix(proto))
-            logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
+            logger.info(
+                "Downloading model from %s: %s -> %s", api_name, api_source, cache_name
+            )
             return download_progress([(api_source, cache_name)])
 
     if source.startswith(model_source_huggingface):
@@ -218,7 +228,9 @@
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)
 
-    ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
+    ctx = ConversionContext(
+        model_path, training_device, half=args.half, opset=args.opset, token=args.token
+    )
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
 
     if not path.exists(model_path):
diff --git a/api/onnx_web/convert/correction_gfpgan.py b/api/onnx_web/convert/correction_gfpgan.py
index d0ec6d18..35830328 100644
--- a/api/onnx_web/convert/correction_gfpgan.py
+++ b/api/onnx_web/convert/correction_gfpgan.py
@@ -1,10 +1,8 @@
 from logging import getLogger
 from os import path
-from shutil import copyfile
 
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils.download_util import load_file_from_url
 from torch.onnx import export
 
 from .utils import ConversionContext, ModelDict
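The `fetch_model` hunks above are easier to follow as a standalone sketch. This is not code from the PR: `resolve_source` and the single table entry are illustrative stand-ins (the Civitai URL template is an assumption), and only the `%`-template expansion, `removeprefix` call, and extension handling mirror the diff. Note that `str.removeprefix` requires Python 3.9+.

```python
from os import path
from typing import Dict, Optional, Tuple

# Hypothetical source table in the shape the hunks imply:
# protocol prefix -> (display name, URL template with one %s slot).
model_sources: Dict[str, Tuple[str, str]] = {
    "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
}


def resolve_source(
    cache_path: str, name: str, source: str, format: Optional[str] = None
) -> Tuple[str, str]:
    # pick the local cache filename, adding an extension when the format
    # is known, since some of the conversion code checks for it
    cache_name = path.join(cache_path, name)
    if format is not None:
        cache_name = f"{cache_name}.{format}"

    # expand a protocol-prefixed source into a concrete download URL
    for proto, (_api_name, api_root) in model_sources.items():
        if source.startswith(proto):
            return api_root % (source.removeprefix(proto)), cache_name

    # anything else is treated as a plain URL or local path
    return source, cache_name


print(resolve_source("/tmp/models", "detailer", "civitai://15236", "safetensors"))
```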
diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py
index 470ce847..35c9a4a2 100644
--- a/api/onnx_web/convert/diffusion_original.py
+++ b/api/onnx_web/convert/diffusion_original.py
@@ -959,6 +959,7 @@ def replace_symlinks(path, base):
         for subpath in os.listdir(path):
             replace_symlinks(os.path.join(path, subpath), base)
 
+
 def download_model(db_config: DreamboothConfig, token):
     tmp_dir = os.path.join(db_config.model_dir, "src")
 
@@ -994,7 +995,7 @@ def download_model(db_config: DreamboothConfig, token):
         if "model_index.json" in name:
             model_index = name
             continue
-        if (".ckpt" in name or ".safetensors" in name) and not "/" in name:
+        if (".ckpt" in name or ".safetensors" in name) and "/" not in name:
            model_files.append(name)
            continue
         for diffusion_dir in diffusion_dirs:
@@ -1034,8 +1035,6 @@ def download_model(db_config: DreamboothConfig, token):
         logger.debug("Nothing to fetch!")
         return None, None
 
-
-    # huggingface_hub.utils.tqdm.tqdm = mytqdm
     mytqdm = huggingface_hub.utils.tqdm.tqdm
     out_model = None
     for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
@@ -1058,7 +1057,7 @@ def download_model(db_config: DreamboothConfig, token):
         for diffusion_dir in diffusion_dirs:
             if diffusion_dir in out:
                 out_model = db_config.pretrained_model_name_or_path
-                dest = os.path.join(db_config.pretrained_model_name_or_path,diffusion_dir)
+                dest = os.path.join(db_config.pretrained_model_name_or_path, diffusion_dir)
         if not dest:
             if ".ckpt" in out or ".safetensors" in out:
                 dest = os.path.join(db_config.model_dir, "src")
@@ -1074,12 +1073,13 @@ def download_model(db_config: DreamboothConfig, token):
 
     return out_model, config_file
 
+
 def get_config_path(
-        model_version: str = "v1",
-        train_type: str = "default",
-        config_base_name: str = "training",
-        prediction_type: str = "epsilon"
-    ):
+    model_version: str = "v1",
+    train_type: str = "default",
+    config_base_name: str = "training",
+    prediction_type: str = "epsilon"
+):
     train_type = f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v"
 
     parts = os.path.join(
@@ -1093,6 +1093,7 @@ def get_config_path(
     )
     return os.path.abspath(parts)
 
+
 def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
     config_base_name = "training"
 
@@ -1117,8 +1118,18 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
 
     return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
 
-def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
-                       new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
+def extract_checkpoint(
+    ctx: ConversionContext,
+    new_model_name: str,
+    checkpoint_file: str,
+    scheduler_type="ddim",
+    from_hub=False,
+    new_model_url="",
+    new_model_token="",
+    extract_ema=False,
+    train_unfrozen=False,
+    is_512=True,
+):
     """
 
     @param new_model_name: The name of the new model
@@ -1153,9 +1164,8 @@
 
     msg = None
     if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
-        msg = "Please provide a URL and token for huggingface models."
-        if msg is not None:
-            return "", "", 0, 0, "", "", "", "", image_size, "", msg
+        logger.warning("Please provide a URL and token for huggingface models.")
+        return
 
     # Create empty config
     db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
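The `"/" not in name` change in the hunk at line 994 is a flake8 E713 fix: `not "/" in name` and `"/" not in name` are equivalent, but the latter is the idiomatic form. A minimal sketch of the corrected filter, with an illustrative file list that is not from the PR:

```python
from typing import List


def top_level_checkpoints(names: List[str]) -> List[str]:
    model_files = []
    for name in names:
        # keep checkpoint files at the repo root; nested copies under
        # unet/, vae/, etc. belong to the diffusers layout instead
        if (".ckpt" in name or ".safetensors" in name) and "/" not in name:
            model_files.append(name)
    return model_files


print(top_level_checkpoints([
    "v1-5-pruned.safetensors",                   # kept: root-level checkpoint
    "unet/diffusion_pytorch_model.safetensors",  # skipped: nested
    "model_index.json",                          # skipped: not a checkpoint
]))
# -> ['v1-5-pruned.safetensors']
```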
@@ -1178,7 +1188,7 @@
         else:
             msg = "Unable to fetch model from hub."
             logger.warning(msg)
-            return "", "", 0, 0, "", "", "", "", image_size, "", msg
+            return
 
     try:
         checkpoint = None
@@ -1194,6 +1204,7 @@
                 logger.debug("Loading safetensors...")
                 checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
             except Exception as e:
+                logger.warning("Failed to load as safetensors file, falling back to torch: %s", e)
                 checkpoint = torch.jit.load(checkpoint_file)
         else:
             logger.debug("Loading ckpt...")
@@ -1231,10 +1242,8 @@
                 if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
                     logger.debug("UNet using v2 parameters.")
                     v2 = True
-
-    except:
-        logger.error("Exception loading unet!")
-        traceback.print_exc()
+    except Exception as e:
+        logger.error("Exception loading unet: %s", traceback.format_exception(e))
 
     if v2 and not is_512:
         prediction_type = "v_prediction"
@@ -1263,7 +1272,7 @@
     if original_config_file is None or not os.path.exists(original_config_file):
         logger.warning("Unable to select a config file: %s" % (original_config_file))
-        return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
+        return
 
     logger.debug(f"Trying to load: {original_config_file}")
     original_config = OmegaConf.load(original_config_file)
@@ -1303,9 +1312,8 @@
     else:
         raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
 
-
-    logger.info("Converting UNet...")
     # Convert the UNet2DConditionModel model.
+    logger.info("Converting UNet...")
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
     unet = UNet2DConditionModel(**unet_config)
@@ -1317,16 +1325,16 @@
         db_config.save()
 
     unet.load_state_dict(converted_unet_checkpoint)
 
-    logger.info("Converting VAE...")
     # Convert the VAE model.
+    logger.info("Converting VAE...")
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
 
-    logger.info("Converting text encoder...")
     # Convert the text model.
+    logger.info("Converting text encoder...")
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
     if text_model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
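The hunk at line 1194 adds a logged fallback when a `.safetensors` file fails to parse. A minimal standalone sketch of that loading order follows; the `load_checkpoint` helper is ours, and the plain-ckpt branch is paraphrased rather than copied from the surrounding code. Note the fallback mirrors the diff's use of `torch.jit.load`, which only succeeds for TorchScript archives:

```python
import logging

import safetensors.torch
import torch

logger = logging.getLogger(__name__)


def load_checkpoint(checkpoint_file: str):
    # try the pickle-free safetensors path first for .safetensors files
    if checkpoint_file.endswith(".safetensors"):
        try:
            logger.debug("loading safetensors...")
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        except Exception as e:
            logger.warning("failed to load as safetensors, falling back to torch: %s", e)
            return torch.jit.load(checkpoint_file)

    logger.debug("loading ckpt...")
    return torch.load(checkpoint_file, map_location="cpu")
```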
@@ -1374,23 +1382,19 @@
 
             pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
     except Exception as e:
-        logger.error(f"Exception setting up output: {e}")
+        logger.error("Exception setting up output: %s", traceback.format_exception(e))
         pipe = None
-        traceback.print_exc()
 
     if pipe is None or db_config is None:
         msg = "Pipeline or config is not set, unable to continue."
         logger.error(msg)
-        return "", "", 0, 0, "", "", "", "", image_size, "", msg
+        return
     else:
-        resolution = db_config.resolution
         logger.info("Saving diffusion model...")
         pipe.save_pretrained(db_config.pretrained_model_name_or_path)
         result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
-        model_dir = db_config.model_dir
         revision = db_config.revision
         scheduler = db_config.scheduler
-        src = db_config.src
         required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
         if original_config_file is not None and os.path.exists(original_config_file):
             logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
@@ -1417,11 +1421,11 @@
 
     if not os.path.exists(rem_dir):
         os.makedirs(rem_dir)
 
-
     logger.info(result_status)
     return
 
+
 def convert_diffusion_original(
     ctx: ConversionContext,
     model: ModelDict,
diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion_stable.py
index fd8a0ecd..9afcddf9 100644
--- a/api/onnx_web/convert/diffusion_stable.py
+++ b/api/onnx_web/convert/diffusion_stable.py
@@ -65,7 +65,9 @@ def convert_diffusion_stable(
     dest_path = path.join(ctx.model_path, name)
 
     # diffusers go into a directory rather than .onnx file
-    logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path)
+    logger.info(
+        "converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
+    )
 
     if single_vae:
         logger.info("converting model with single VAE")
diff --git a/api/onnx_web/convert/upscale_resrgan.py b/api/onnx_web/convert/upscale_resrgan.py
index 51a95804..88d05ce4 100644
--- a/api/onnx_web/convert/upscale_resrgan.py
+++ b/api/onnx_web/convert/upscale_resrgan.py
@@ -1,10 +1,8 @@
 from logging import getLogger
 from os import path
-from shutil import copyfile
 
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils.download_util import load_file_from_url
 from torch.onnx import export
 
 from .utils import ConversionContext, ModelDict
diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py
index 5a984900..141481ee 100644
--- a/api/onnx_web/convert/utils.py
+++ b/api/onnx_web/convert/utils.py
@@ -3,7 +3,7 @@ from functools import partial
 from logging import getLogger
 from os import path
 from pathlib import Path
-from typing import Dict, Union, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import requests
 import torch
@@ -44,7 +44,14 @@ def download_progress(urls: List[Tuple[str, str]]):
         logger.info("Destination already exists: %s", dest_path)
         return str(dest_path.absolute())
 
-    req = requests.get(url, stream=True, allow_redirects=True)
+    req = requests.get(
+        url,
+        stream=True,
+        allow_redirects=True,
+        headers={
+            "User-Agent": "onnx-web-api",
+        },
+    )
     if req.status_code != 200:
         req.raise_for_status()  # Only works for 4xx errors, per SO answer
         raise RuntimeError(
@@ -117,6 +124,7 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
 known_formats = ["onnx", "pth", "ckpt", "safetensors"]
 
+
 def source_format(model: Dict) -> Optional[str]:
     if "format" in model:
         return model["format"]
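The `download_progress` hunk adds an explicit `User-Agent` header, since some model hosts reject requests that do not identify a client. A minimal sketch of the request setup; the `open_download` wrapper is ours, but the arguments and status handling mirror the diff (`raise_for_status` actually raises for both 4xx and 5xx, so the trailing `RuntimeError` covers any other non-200 status):

```python
import requests


def open_download(url: str) -> requests.Response:
    # stream the body so large model files are not buffered in memory,
    # follow redirects, and send an explicit User-Agent
    req = requests.get(
        url,
        stream=True,
        allow_redirects=True,
        headers={
            "User-Agent": "onnx-web-api",
        },
    )
    if req.status_code != 200:
        # raises for 4xx/5xx responses
        req.raise_for_status()
        # guard any remaining non-200 status explicitly
        raise RuntimeError("download failed: %s for %s" % (req.status_code, url))
    return req
```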