1
0
Fork 0
onnx-web/api/onnx_web/convert/__main__.py

381 lines
13 KiB
Python
Raw Normal View History

import warnings
from argparse import ArgumentParser
from logging import getLogger
from os import makedirs, path
from sys import exit
2023-02-25 14:22:12 +00:00
from traceback import format_exception
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
from jsonschema import ValidationError, validate
from yaml import safe_load
from .correction_gfpgan import convert_correction_gfpgan
from .diffusion.diffusers import convert_diffusion_diffusers
2023-02-22 05:50:27 +00:00
from .diffusion.original import convert_diffusion_original
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
from .upscale_resrgan import convert_upscale_resrgan
from .utils import (
ConversionContext,
download_progress,
2023-02-11 20:19:42 +00:00
model_formats_original,
remove_prefix,
source_format,
tuple_to_correction,
tuple_to_diffusion,
tuple_to_source,
tuple_to_upscaling,
)
# Silence common but harmless warnings that are emitted while models are
# traced and exported, see https://github.com/ssube/onnx-web/issues/75
for _harmless_warning in (
    ".*The shape inference of prim::Constant type is missing.*",
    ".*Only steps=1 can be constant folded.*",
    ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
):
    warnings.filterwarnings("ignore", _harmless_warning)
del _harmless_warning

# Model group name -> list of model entries belonging to that group.
Models = Dict[str, List[Any]]

logger = getLogger(__name__)

# Source prefix -> (display name, download URL template with one %s placeholder).
model_sources: Dict[str, Tuple[str, str]] = {
    "civitai://": (
        "Civitai",
        "https://civitai.com/api/download/models/%s",
    ),
}

# Prefix used by model sources that point at the Hugging Face Hub.
model_source_huggingface = "huggingface://"
# Recommended models, declared one group at a time and assembled into
# `base_models` below. Every entry is a tuple starting with the local model
# name followed by its source; some groups carry a third item.
_recommended_diffusion: List[Any] = [
    # v1.x
    (
        "stable-diffusion-onnx-v1-5",
        model_source_huggingface + "runwayml/stable-diffusion-v1-5",
    ),
    (
        "stable-diffusion-onnx-v1-inpainting",
        model_source_huggingface + "runwayml/stable-diffusion-inpainting",
    ),
    # v2.x
    (
        "stable-diffusion-onnx-v2-1",
        model_source_huggingface + "stabilityai/stable-diffusion-2-1",
    ),
    (
        "stable-diffusion-onnx-v2-inpainting",
        model_source_huggingface + "stabilityai/stable-diffusion-2-inpainting",
    ),
    # TODO: should have its own converter
    (
        "upscaling-stable-diffusion-x4",
        model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
        True,
    ),
]

# NOTE(review): the trailing integer looks like a scale factor - confirm
# against tuple_to_correction.
_recommended_correction: List[Any] = [
    (
        "correction-gfpgan-v1-3",
        "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
        4,
    ),
    (
        "correction-codeformer",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
        1,
    ),
]

# NOTE(review): the trailing integer looks like a scale factor - confirm
# against tuple_to_upscaling.
_recommended_upscaling: List[Any] = [
    (
        "upscaling-real-esrgan-x2-plus",
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
        2,
    ),
    (
        "upscaling-real-esrgan-x4-plus",
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
        4,
    ),
    (
        "upscaling-real-esrgan-x4-v3",
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
        4,
    ),
]

# download only
_recommended_sources: List[Any] = [
    (
        "detection-resnet50-final",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
    ),
    (
        "detection-mobilenet025-final",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth",
    ),
    (
        "detection-yolo-v5-l",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth",
    ),
    (
        "detection-yolo-v5-n",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth",
    ),
    (
        "parsing-bisenet",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth",
    ),
    (
        "parsing-parsenet",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
    ),
]

# recommended models, keyed by model group
base_models: Models = {
    "diffusion": _recommended_diffusion,
    "correction": _recommended_correction,
    "upscaling": _recommended_upscaling,
    "sources": _recommended_sources,
}
def fetch_model(
    ctx: ConversionContext, name: str, source: str, model_format: Optional[str] = None
) -> str:
    """
    Resolve a model source into something the converters can load, downloading
    it into the cache when the source is a remote file.

    Args:
        ctx: conversion context, read for its `cache_path` and `model_path`.
        name: local name of the model, used to build the cache and model paths.
        source: where the model comes from: a registered API prefix from
            `model_sources`, a `huggingface://` id, an http(s) URL, or a path.
        model_format: optional format used as the cache file extension; when
            omitted, the extension is taken from the path part of `source`.

    Returns:
        The existing model path if the model is already present, the result of
        `download_progress` for remote files, the hub id with the prefix removed
        for Hugging Face sources, or `source` unchanged for everything else.
    """
    cache_name = path.join(ctx.cache_path, name)
    model_path = path.join(ctx.model_path, name)
    model_onnx = model_path + ".onnx"

    for existing in (model_path, model_onnx):
        if path.exists(existing):
            logger.debug("model already exists, skipping fetch")
            return existing

    # add an extension if possible, some of the conversion code checks for it
    if model_format is None:
        url = urlparse(source)
        _filename, ext = path.splitext(path.basename(url.path))
        # splitext yields "" (never None) when the name has no extension
        if ext:
            cache_name += ext
    else:
        cache_name = "%s.%s" % (cache_name, model_format)

    # sources using a registered API prefix are rewritten to their download URL
    for proto, (api_name, api_root) in model_sources.items():
        if source.startswith(proto):
            api_source = api_root % (remove_prefix(source, proto))
            logger.info(
                "downloading model from %s: %s -> %s", api_name, api_source, cache_name
            )
            return download_progress([(api_source, cache_name)])

    if source.startswith(model_source_huggingface):
        hub_source = remove_prefix(source, model_source_huggingface)
        logger.info("downloading model from Huggingface Hub: %s", hub_source)
        # from_pretrained has a bunch of useful logic that snapshot_download by itself does not,
        # so only the hub id is returned here and nothing is downloaded yet
        return hub_source
    elif source.startswith("https://"):
        logger.info("downloading model from: %s", source)
        return download_progress([(source, cache_name)])
    elif source.startswith("http://"):
        logger.warning("downloading model from insecure source: %s", source)
        return download_progress([(source, cache_name)])
    elif source.startswith((path.sep, ".")):
        logger.info("using local model: %s", source)
        return source
    else:
        logger.info("unknown model location, using path as provided: %s", source)
        return source
def convert_models(ctx: ConversionContext, args, models: Models):
    """
    Fetch and convert every enabled, non-skipped model in `models`.

    Args:
        ctx: conversion context passed through to the fetch and convert steps.
        args: parsed CLI arguments; the `sources`, `diffusion`, `upscaling` and
            `correction` flags enable each group and `skip` lists model names
            to leave out.
        models: model entries keyed by group name; missing groups are ignored.

    Errors raised while fetching or converting a single model are logged and do
    not stop the remaining models from being processed.
    """
    if args.sources and "sources" in models:
        for entry in models["sources"]:
            model = tuple_to_source(entry)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping source: %s", name)
                continue

            model_format = source_format(model)
            source = model["source"]

            try:
                dest = fetch_model(ctx, name, source, model_format=model_format)
                logger.info("finished downloading source: %s -> %s", source, dest)
            except Exception as e:
                logger.error("error fetching source %s: %s", name, e)

    if args.diffusion and "diffusion" in models:
        for entry in models["diffusion"]:
            model = tuple_to_diffusion(entry)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping model: %s", name)
                continue

            model_format = source_format(model)

            try:
                _convert_diffusion_model(ctx, name, model, model_format)
            except Exception as e:
                logger.error(
                    "error converting diffusion model %s: %s",
                    name,
                    _format_traceback(e),
                )

    if args.upscaling and "upscaling" in models:
        _convert_simple_models(
            ctx,
            args,
            models["upscaling"],
            tuple_to_upscaling,
            convert_upscale_resrgan,
            "upscaling",
        )

    if args.correction and "correction" in models:
        _convert_simple_models(
            ctx,
            args,
            models["correction"],
            tuple_to_correction,
            convert_correction_gfpgan,
            "correction",
        )


def _format_traceback(err: BaseException) -> str:
    """Render an exception and its traceback as one string suitable for logging."""
    # format_exception returns a list of lines, each already newline-terminated
    return "".join(format_exception(type(err), err, err.__traceback__))


def _convert_diffusion_model(
    ctx: ConversionContext, name: str, model: Any, model_format: Optional[str]
) -> None:
    """Fetch and convert one diffusion model, followed by its textual inversions."""
    source = fetch_model(ctx, name, model["source"], model_format=model_format)

    if model_format in model_formats_original:
        convert_diffusion_original(ctx, model, source)
    else:
        convert_diffusion_diffusers(ctx, model, source)

    for inversion in model.get("inversions", []):
        inversion_name = inversion["name"]
        inversion_format = inversion.get("format", "huggingface")
        inversion_source = fetch_model(
            ctx, f"{name}-inversion-{inversion_name}", inversion["source"]
        )
        convert_diffusion_textual_inversion(
            ctx,
            inversion_name,
            model["source"],
            inversion_source,
            inversion_format,
        )


def _convert_simple_models(
    ctx: ConversionContext,
    args,
    entries: List[Any],
    to_model,
    converter,
    kind: str,
) -> None:
    """Fetch and convert each non-skipped entry of a group handled by a single converter."""
    for entry in entries:
        model = to_model(entry)
        name = model.get("name")

        if name in args.skip:
            logger.info("skipping model: %s", name)
            continue

        model_format = source_format(model)

        try:
            source = fetch_model(ctx, name, model["source"], model_format=model_format)
            converter(ctx, model, source)
        except Exception as e:
            logger.error(
                "error converting %s model %s: %s",
                kind,
                name,
                _format_traceback(e),
            )
def main() -> int:
    """
    Command-line entry point: parse the arguments, convert the base models and
    then the models listed in each extras file.

    Returns:
        0 once all groups have been processed; failures in an extras file are
        logged rather than raised.

    Raises:
        ValueError: when `--half` is requested and the training device is not "cuda".
    """
    parser = ArgumentParser(
        prog="onnx-web model converter", description="convert checkpoint models to ONNX"
    )

    # model groups
    parser.add_argument("--sources", action="store_true", default=False)
    parser.add_argument("--correction", action="store_true", default=False)
    parser.add_argument("--diffusion", action="store_true", default=False)
    parser.add_argument("--upscaling", action="store_true", default=False)

    # extra models
    parser.add_argument("--extras", nargs="*", type=str, default=[])
    parser.add_argument("--skip", nargs="*", type=str, default=[])

    # export options
    parser.add_argument(
        "--half",
        action="store_true",
        default=False,
        help="Export models for half precision, faster on some Nvidia cards.",
    )
    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument(
        "--token",
        type=str,
        help="HuggingFace token with read permissions for downloading models.",
    )

    args = parser.parse_args()

    # the token is a credential, mask it so it never ends up in the logs
    logged_args = {
        key: ("***" if key == "token" and value else value)
        for key, value in vars(args).items()
    }
    logger.info("CLI arguments: %s", logged_args)

    ctx = ConversionContext.from_environ()
    ctx.half = args.half
    ctx.opset = args.opset
    ctx.token = args.token
    logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device)

    if ctx.half and ctx.training_device != "cuda":
        raise ValueError(
            "half precision model export is only supported on GPUs with CUDA"
        )

    if not path.exists(ctx.model_path):
        logger.info("model path does not exist, creating: %s", ctx.model_path)
        # exist_ok avoids failing if the directory appears between the check and the call
        makedirs(ctx.model_path, exist_ok=True)

    logger.info("converting base models")
    convert_models(ctx, args, base_models)

    for file in args.extras:
        if not file:
            continue

        logger.info("loading extra models from %s", file)
        try:
            with open(file, "r") as f:
                data = safe_load(f.read())

            # NOTE: the schema path is relative to the current working directory
            with open("./schemas/extras.yaml", "r") as f:
                schema = safe_load(f.read())

            logger.debug("validating extras file: %s against %s", data, schema)

            try:
                validate(data, schema)
                logger.info("converting extra models")
                convert_models(ctx, args, data)
            except ValidationError as err:
                logger.error("invalid data in extras file: %s", err)
        except Exception as err:
            # best effort: a broken extras file must not prevent the others from loading
            logger.error("error converting extra models: %s", err)

    return 0
# When run as a script, convert the models and use the result of main() as the
# process exit status.
if __name__ == "__main__":
    status = main()
    exit(status)