"""Command-line entry point that downloads models and converts them to ONNX.

The module defines a table of recommended base models, a helper that resolves
a model source string to something the converters can load, and a ``main``
function that converts the base models followed by any models listed in
user-provided "extras" YAML files.
"""
import warnings
from argparse import ArgumentParser
from logging import getLogger
from os import makedirs, path
from sys import exit
from traceback import format_exception
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

from jsonschema import ValidationError, validate
from yaml import safe_load

from .correction_gfpgan import convert_correction_gfpgan
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.original import convert_diffusion_original
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
from .upscale_resrgan import convert_upscale_resrgan
from .utils import (
    ConversionContext,
    download_progress,
    model_formats_original,
    remove_prefix,
    source_format,
    tuple_to_correction,
    tuple_to_diffusion,
    tuple_to_source,
    tuple_to_upscaling,
)

# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
warnings.filterwarnings(
    "ignore", ".*The shape inference of prim::Constant type is missing.*"
)
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
warnings.filterwarnings(
    "ignore",
    ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
)

# maps a model group name ("diffusion", "correction", ...) to its model entries
Models = Dict[str, List[Any]]

logger = getLogger(__name__)

# URL-like prefixes that are rewritten into a download URL:
# prefix -> (display name, URL template taking the remainder of the source)
model_sources: Dict[str, Tuple[str, str]] = {
    "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
}

model_source_huggingface = "huggingface://"

# recommended models
base_models: Models = {
    "diffusion": [
        # v1.x
        (
            "stable-diffusion-onnx-v1-5",
            model_source_huggingface + "runwayml/stable-diffusion-v1-5",
        ),
        (
            "stable-diffusion-onnx-v1-inpainting",
            model_source_huggingface + "runwayml/stable-diffusion-inpainting",
        ),
        # v2.x
        (
            "stable-diffusion-onnx-v2-1",
            model_source_huggingface + "stabilityai/stable-diffusion-2-1",
        ),
        (
            "stable-diffusion-onnx-v2-inpainting",
            model_source_huggingface + "stabilityai/stable-diffusion-2-inpainting",
        ),
        # TODO: should have its own converter
        (
            "upscaling-stable-diffusion-x4",
            model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
            True,
        ),
    ],
    "correction": [
        (
            "correction-gfpgan-v1-3",
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
            4,
        ),
        (
            "correction-codeformer",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
            1,
        ),
    ],
    "upscaling": [
        (
            "upscaling-real-esrgan-x2-plus",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            2,
        ),
        (
            "upscaling-real-esrgan-x4-plus",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            4,
        ),
        (
            "upscaling-real-esrgan-x4-v3",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
            4,
        ),
    ],
    # download only
    "sources": [
        (
            "detection-resnet50-final",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
        ),
        (
            "detection-mobilenet025-final",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth",
        ),
        (
            "detection-yolo-v5-l",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth",
        ),
        (
            "detection-yolo-v5-n",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth",
        ),
        (
            "parsing-bisenet",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth",
        ),
        (
            "parsing-parsenet",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
        ),
    ],
}


def fetch_model(
    ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None
) -> str:
    """Resolve a model source to something the converters can load.

    Args:
        ctx: conversion context providing ``cache_path`` and ``model_path``.
        name: model name, used as the file name under both paths.
        source: where the model comes from: a prefix registered in
            ``model_sources``, a ``huggingface://`` identifier, an HTTP(S)
            URL, or a local path.
        model_format: optional format used as the cache file extension. When
            omitted, the extension is taken from the path part of ``source``.

    Returns:
        The existing path when the model (with or without an ``.onnx``
        suffix) is already present under ``ctx.model_path``; the result of
        ``download_progress`` for downloadable sources; the Hub identifier
        with its prefix removed for ``huggingface://`` sources; otherwise
        ``source`` unchanged.
    """
    cache_name = path.join(ctx.cache_path, name)
    model_path = path.join(ctx.model_path, name)
    model_onnx = model_path + ".onnx"

    for existing in (model_path, model_onnx):
        if path.exists(existing):
            logger.debug("model already exists, skipping fetch")
            return existing

    # add an extension if possible, some of the conversion code checks for it
    if model_format is None:
        url = urlparse(source)
        _filename, ext = path.splitext(path.basename(url.path))
        # splitext never returns None; the extension is "" when there is none
        if ext:
            cache_name += ext
    else:
        cache_name = "%s.%s" % (cache_name, model_format)

    for proto, (api_name, api_root) in model_sources.items():
        if source.startswith(proto):
            api_source = api_root % (remove_prefix(source, proto))
            logger.info(
                "downloading model from %s: %s -> %s", api_name, api_source, cache_name
            )
            return download_progress([(api_source, cache_name)])

    if source.startswith(model_source_huggingface):
        hub_source = remove_prefix(source, model_source_huggingface)
        logger.info("downloading model from Huggingface Hub: %s", hub_source)
        # from_pretrained has a bunch of useful logic that snapshot_download by itself does not
        return hub_source

    if source.startswith("https://"):
        logger.info("downloading model from: %s", source)
        return download_progress([(source, cache_name)])

    if source.startswith("http://"):
        logger.warning("downloading model from insecure source: %s", source)
        return download_progress([(source, cache_name)])

    if source.startswith((path.sep, ".")):
        logger.info("using local model: %s", source)
        return source

    logger.info("unknown model location, using path as provided: %s", source)
    return source


def _format_error(err: BaseException) -> str:
    """Render an exception and its traceback as one printable string."""
    return "".join(format_exception(type(err), err, err.__traceback__))


def convert_models(ctx: ConversionContext, args, models: Models):
    """Fetch and convert every enabled model group found in ``models``.

    A group is processed only when its flag on ``args`` (``sources``,
    ``diffusion``, ``upscaling``, ``correction``) is set and the group key is
    present in ``models``. Models whose name appears in ``args.skip`` are
    skipped. A failure on one model is logged and does not stop the others.

    Args:
        ctx: conversion context passed through to the fetch/convert helpers.
        args: parsed CLI arguments carrying the group flags and ``skip`` list.
        models: mapping from group name to the list of model entries.
    """
    if args.sources and "sources" in models:
        # a key with a null value (e.g. an empty YAML section) is treated as empty
        for model in models.get("sources") or []:
            model = tuple_to_source(model)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping source: %s", name)
                continue

            model_format = source_format(model)
            source = model["source"]

            try:
                dest = fetch_model(ctx, name, source, model_format=model_format)
                logger.info("finished downloading source: %s -> %s", source, dest)
            except Exception as e:
                logger.error("error fetching source %s: %s", name, e)

    if args.diffusion and "diffusion" in models:
        for model in models.get("diffusion") or []:
            model = tuple_to_diffusion(model)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping model: %s", name)
                continue

            model_format = source_format(model)

            try:
                source = fetch_model(
                    ctx, name, model["source"], model_format=model_format
                )

                if model_format in model_formats_original:
                    convert_diffusion_original(
                        ctx,
                        model,
                        source,
                    )
                else:
                    convert_diffusion_diffusers(
                        ctx,
                        model,
                        source,
                    )

                for inversion in model.get("inversions", []):
                    inversion_name = inversion["name"]
                    inversion_source = inversion["source"]
                    inversion_format = inversion.get("format", "huggingface")
                    inversion_source = fetch_model(
                        ctx, f"{name}-inversion-{inversion_name}", inversion_source
                    )
                    convert_diffusion_textual_inversion(
                        ctx,
                        inversion_name,
                        model["source"],
                        inversion_source,
                        inversion_format,
                    )
            except Exception as e:
                logger.error(
                    "error converting diffusion model %s: %s",
                    name,
                    _format_error(e),
                )

    if args.upscaling and "upscaling" in models:
        for model in models.get("upscaling") or []:
            model = tuple_to_upscaling(model)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping model: %s", name)
                continue

            model_format = source_format(model)

            try:
                source = fetch_model(
                    ctx, name, model["source"], model_format=model_format
                )
                convert_upscale_resrgan(ctx, model, source)
            except Exception as e:
                logger.error(
                    "error converting upscaling model %s: %s",
                    name,
                    _format_error(e),
                )

    if args.correction and "correction" in models:
        for model in models.get("correction") or []:
            model = tuple_to_correction(model)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping model: %s", name)
                continue

            model_format = source_format(model)

            try:
                source = fetch_model(
                    ctx, name, model["source"], model_format=model_format
                )
                convert_correction_gfpgan(ctx, model, source)
            except Exception as e:
                logger.error(
                    "error converting correction model %s: %s",
                    name,
                    _format_error(e),
                )


def main() -> int:
    """Parse CLI arguments, then convert the base models and any extras.

    Returns:
        ``0`` once all requested groups have been processed.

    Raises:
        ValueError: when ``--half`` is requested and the context's training
            device is not ``"cuda"``.
    """
    parser = ArgumentParser(
        prog="onnx-web model converter", description="convert checkpoint models to ONNX"
    )

    # model groups
    parser.add_argument("--sources", action="store_true", default=False)
    parser.add_argument("--correction", action="store_true", default=False)
    parser.add_argument("--diffusion", action="store_true", default=False)
    parser.add_argument("--upscaling", action="store_true", default=False)

    # extra models
    parser.add_argument("--extras", nargs="*", type=str, default=[])
    parser.add_argument("--skip", nargs="*", type=str, default=[])

    # export options
    parser.add_argument(
        "--half",
        action="store_true",
        default=False,
        help="Export models for half precision, faster on some Nvidia cards.",
    )
    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument(
        "--token",
        type=str,
        help="HuggingFace token with read permissions for downloading models.",
    )

    args = parser.parse_args()
    logger.info("CLI arguments: %s", args)

    ctx = ConversionContext.from_environ()
    ctx.half = args.half
    ctx.opset = args.opset
    ctx.token = args.token
    logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device)

    if ctx.half and ctx.training_device != "cuda":
        raise ValueError(
            "half precision model export is only supported on GPUs with CUDA"
        )

    if not path.exists(ctx.model_path):
        logger.info("model path does not existing, creating: %s", ctx.model_path)
        # exist_ok avoids a crash if the directory appears after the check
        makedirs(ctx.model_path, exist_ok=True)

    logger.info("converting base models")
    convert_models(ctx, args, base_models)

    for file in args.extras:
        if not file:
            continue

        logger.info("loading extra models from %s", file)
        try:
            # read as UTF-8 explicitly rather than the platform default encoding
            with open(file, "r", encoding="utf-8") as f:
                data = safe_load(f.read())

            with open("./schemas/extras.yaml", "r", encoding="utf-8") as f:
                schema = safe_load(f.read())

            logger.debug("validating chain request: %s against %s", data, schema)

            try:
                validate(data, schema)
                logger.info("converting extra models")
                convert_models(ctx, args, data)
            except ValidationError as err:
                logger.error("invalid data in extras file: %s", err)
        except Exception as err:
            logger.error("error converting extra models: %s", err)

    return 0


if __name__ == "__main__":
    exit(main())