2023-02-09 04:35:54 +00:00
|
|
|
import warnings
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
from logging import getLogger
|
2023-02-11 18:36:54 +00:00
|
|
|
from os import makedirs, path
|
2023-02-09 04:35:54 +00:00
|
|
|
from sys import exit
|
2023-03-01 04:30:29 +00:00
|
|
|
from typing import Any, Dict, List, Optional, Tuple
|
2023-02-12 12:25:44 +00:00
|
|
|
from urllib.parse import urlparse
|
2023-02-09 04:35:54 +00:00
|
|
|
|
2023-03-20 01:16:52 +00:00
|
|
|
from huggingface_hub.file_download import hf_hub_download
|
2023-02-11 05:32:16 +00:00
|
|
|
from jsonschema import ValidationError, validate
|
2023-03-18 12:01:16 +00:00
|
|
|
from onnx import load_model, save_model
|
|
|
|
from transformers import CLIPTokenizer
|
2023-02-09 04:35:54 +00:00
|
|
|
|
2023-03-24 13:14:19 +00:00
|
|
|
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
2023-05-04 03:15:17 +00:00
|
|
|
from ..utils import load_config
|
2023-04-10 22:49:56 +00:00
|
|
|
from .correction.gfpgan import convert_correction_gfpgan
|
2023-04-20 12:36:31 +00:00
|
|
|
from .diffusion.control import convert_diffusion_control
|
2023-03-05 05:03:15 +00:00
|
|
|
from .diffusion.diffusers import convert_diffusion_diffusers
|
2023-03-18 12:01:16 +00:00
|
|
|
from .diffusion.lora import blend_loras
|
|
|
|
from .diffusion.textual_inversion import blend_textual_inversions
|
2023-04-10 22:49:56 +00:00
|
|
|
from .upscaling.bsrgan import convert_upscaling_bsrgan
|
|
|
|
from .upscaling.resrgan import convert_upscale_resrgan
|
|
|
|
from .upscaling.swinir import convert_upscaling_swinir
|
2023-02-11 05:32:16 +00:00
|
|
|
from .utils import (
|
2023-05-15 01:04:43 +00:00
|
|
|
DEFAULT_OPSET,
|
2023-02-11 05:32:16 +00:00
|
|
|
ConversionContext,
|
|
|
|
download_progress,
|
2023-02-15 03:23:16 +00:00
|
|
|
remove_prefix,
|
2023-02-11 05:32:16 +00:00
|
|
|
source_format,
|
|
|
|
tuple_to_correction,
|
|
|
|
tuple_to_diffusion,
|
2023-02-12 15:28:37 +00:00
|
|
|
tuple_to_source,
|
2023-02-11 05:32:16 +00:00
|
|
|
tuple_to_upscaling,
|
|
|
|
)
|
2023-02-11 04:41:24 +00:00
|
|
|
|
2023-02-09 04:35:54 +00:00
|
|
|
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
|
|
|
warnings.filterwarnings(
|
|
|
|
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
|
|
|
)
|
|
|
|
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
|
|
|
|
warnings.filterwarnings(
|
|
|
|
"ignore",
|
|
|
|
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
|
|
|
)
|
|
|
|
|
2023-03-01 03:44:52 +00:00
|
|
|
# catalog type: model-group name (e.g. "diffusion", "upscaling", "sources")
# mapped to a list of model entries (positional tuples or dicts)
Models = Dict[str, List[Any]]

# module-level logger named after this module
logger = getLogger(__name__)
|
|
|
|
|
|
|
|
|
2023-02-11 04:41:24 +00:00
|
|
|
# custom URL protocols mapped to (display name, download URL template);
# the %s placeholder is filled with the remainder of the source after the
# protocol prefix (see fetch_model)
model_sources: Dict[str, Tuple[str, str]] = {
    "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
}

# prefix marking a source that is resolved through the HuggingFace Hub
model_source_huggingface = "huggingface://"
|
|
|
|
|
2023-02-09 04:35:54 +00:00
|
|
|
# recommended models
# entries are either positional tuples (name, source, ...) or dicts with
# explicit keys; tuple entries are normalized by the tuple_to_* helpers
# before conversion (see convert_models)
base_models: Models = {
    "diffusion": [
        # v1.x
        (
            "stable-diffusion-onnx-v1-5",
            model_source_huggingface + "runwayml/stable-diffusion-v1-5",
        ),
        (
            "stable-diffusion-onnx-v1-inpainting",
            model_source_huggingface + "runwayml/stable-diffusion-inpainting",
        ),
        (
            "upscaling-stable-diffusion-x4",
            model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
            # third tuple element — interpreted by tuple_to_diffusion;
            # presumably flags the upscaling pipeline variant — TODO confirm
            True,
        ),
    ],
    "correction": [
        # tuples here are (name, source, scale) — see tuple_to_correction
        (
            "correction-gfpgan-v1-3",
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
            4,
        ),
        (
            "correction-codeformer",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
            1,
        ),
    ],
    "upscaling": [
        # tuples here are (name, source, scale) — see tuple_to_upscaling
        (
            "upscaling-real-esrgan-x2-plus",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            2,
        ),
        (
            "upscaling-real-esrgan-x4-plus",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            4,
        ),
        (
            "upscaling-real-esrgan-x4-v3",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
            4,
        ),
        # dict entries carry an explicit "model" key selecting the converter
        # (see the model_type dispatch in convert_models)
        {
            "model": "swinir",
            "name": "upscaling-swinir-classical-x4",
            "source": "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth",
            "scale": 4,
        },
        {
            "model": "swinir",
            "name": "upscaling-swinir-real-large-x4",
            "source": "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth",
            "scale": 4,
        },
        {
            "model": "bsrgan",
            "name": "upscaling-bsrgan-x4",
            "source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
            "scale": 4,
        },
        {
            "model": "bsrgan",
            "name": "upscaling-bsrgan-x2",
            "source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGANx2.pth",
            "scale": 2,
        },
    ],
    # download only
    "sources": [
        # CodeFormer: no ONNX yet
        (
            "detection-resnet50-final",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
        ),
        (
            "detection-mobilenet025-final",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth",
        ),
        (
            "detection-yolo-v5-l",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth",
        ),
        (
            "detection-yolo-v5-n",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth",
        ),
        (
            "parsing-bisenet",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth",
        ),
        (
            "parsing-parsenet",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
        ),
        # ControlNets: already converted
        # "dest" places the file under models/control; "format" skips
        # extension guessing in fetch_model
        {
            "dest": "control",
            "format": "onnx",
            "name": "canny",
            "source": "https://huggingface.co/ForserX/sd-controlnet-canny-onnx/resolve/main/model.onnx",
        },
        {
            "dest": "control",
            "format": "onnx",
            "name": "depth",
            "source": "https://huggingface.co/ForserX/sd-controlnet-depth-onnx/resolve/main/model.onnx",
        },
        {
            "dest": "control",
            "format": "onnx",
            "name": "hed",
            "source": "https://huggingface.co/ForserX/sd-controlnet-hed-onnx/resolve/main/model.onnx",
        },
        {
            "dest": "control",
            "format": "onnx",
            "name": "mlsd",
            "source": "https://huggingface.co/ForserX/sd-controlnet-mlsd-onnx/resolve/main/model.onnx",
        },
        {
            "dest": "control",
            "format": "onnx",
            "name": "normal",
            "source": "https://huggingface.co/ForserX/sd-controlnet-normal-onnx/resolve/main/model.onnx",
        },
        {
            "dest": "control",
            "format": "onnx",
            "name": "openpose",
            "source": "https://huggingface.co/ForserX/sd-controlnet-openpose-onnx/resolve/main/model.onnx",
        },
        {
            "dest": "control",
            "format": "onnx",
            "name": "seg",
            "source": "https://huggingface.co/ForserX/sd-controlnet-seg-onnx/resolve/main/model.onnx",
        },
    ],
}
|
|
|
|
|
|
|
|
|
2023-02-11 05:32:16 +00:00
|
|
|
def fetch_model(
    conversion: ConversionContext,
    name: str,
    source: str,
    dest: Optional[str] = None,
    format: Optional[str] = None,
    hf_hub_fetch: bool = False,
    hf_hub_filename: Optional[str] = None,
) -> Tuple[str, bool]:
    """Resolve a model source to a local path, downloading it if needed.

    Args:
        conversion: server context providing the default cache path.
        name: cache file name (without extension) for the downloaded model.
        source: model location; may use a registered protocol prefix
            (``civitai://``, ``huggingface://``), a plain HTTP(S) URL, or a
            local path.
        dest: optional directory overriding the context cache path.
        format: optional file extension to append to the cache name; when
            omitted, the extension is taken from the source URL. (Parameter
            name shadows the builtin but is kept for caller compatibility.)
        hf_hub_fetch: download a single file through the HuggingFace Hub API
            instead of returning the repo name for ``from_pretrained``.
        hf_hub_filename: file to fetch from the Hub repo when ``hf_hub_fetch``.

    Returns:
        Tuple of (local path or HF repo name, flag that is True only when the
        first element is a HuggingFace repo name rather than a local file).
    """
    cache_path = dest or conversion.cache_path
    cache_name = path.join(cache_path, name)

    # add an extension if possible, some of the conversion code checks for it
    if format is None:
        url = urlparse(source)
        ext = path.basename(url.path)
        _filename, ext = path.splitext(ext)
        # splitext returns an empty string (never None) when there is no
        # extension, so test truthiness instead of `is not None`
        if ext:
            cache_name = cache_name + ext
    else:
        cache_name = f"{cache_name}.{format}"

    if path.exists(cache_name):
        logger.debug("model already exists in cache, skipping fetch")
        return cache_name, False

    # check registered protocol prefixes (e.g. civitai://) first;
    # iterate items() to avoid a second lookup per protocol
    for proto, (api_name, api_root) in model_sources.items():
        if source.startswith(proto):
            api_source = api_root % (remove_prefix(source, proto))
            logger.info(
                "downloading model from %s: %s -> %s", api_name, api_source, cache_name
            )
            return download_progress([(api_source, cache_name)]), False

    if source.startswith(model_source_huggingface):
        hub_source = remove_prefix(source, model_source_huggingface)
        logger.info("downloading model from Huggingface Hub: %s", hub_source)
        # from_pretrained has a bunch of useful logic that snapshot_download by itself does not
        if hf_hub_fetch:
            return (
                hf_hub_download(
                    repo_id=hub_source,
                    filename=hf_hub_filename,
                    cache_dir=cache_path,
                    force_filename=f"{name}.bin",
                ),
                False,
            )
        else:
            # defer the download to from_pretrained, which accepts a bare repo name
            return hub_source, True
    elif source.startswith("https://"):
        logger.info("downloading model from: %s", source)
        return download_progress([(source, cache_name)]), False
    elif source.startswith("http://"):
        logger.warning("downloading model from insecure source: %s", source)
        return download_progress([(source, cache_name)]), False
    elif source.startswith(path.sep) or source.startswith("."):
        logger.info("using local model: %s", source)
        return source, False
    else:
        logger.info("unknown model location, using path as provided: %s", source)
        return source, False
|
2023-02-11 04:41:24 +00:00
|
|
|
|
|
|
|
|
2023-04-10 01:33:03 +00:00
|
|
|
def convert_models(conversion: ConversionContext, args, models: Models):
    """Download and convert every requested model group in *models*.

    Processes the "sources", "networks", "diffusion", "upscaling" and
    "correction" groups, gated by the matching flags on *args*. Names listed
    in ``args.skip`` are skipped. Failures are caught per model, logged, and
    collected into a summary error log at the end, so one broken model does
    not abort the whole run.
    """
    # names of models that failed to fetch or convert; reported at the end
    model_errors = []

    # plain file downloads, no conversion step
    if args.sources and "sources" in models:
        for model in models.get("sources"):
            # normalize tuple entries into dicts
            model = tuple_to_source(model)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping source: %s", name)
            else:
                model_format = source_format(model)
                source = model["source"]

                try:
                    # optional "dest" key relocates the file under model_path
                    dest_path = None
                    if "dest" in model:
                        dest_path = path.join(conversion.model_path, model["dest"])

                    dest, hf = fetch_model(
                        conversion, name, source, format=model_format, dest=dest_path
                    )
                    logger.info("finished downloading source: %s -> %s", source, dest)
                except Exception:
                    logger.exception("error fetching source %s", name)
                    model_errors.append(name)

    # extra networks: ControlNets, textual inversions, etc.
    if args.networks and "networks" in models:
        for network in models.get("networks"):
            name = network["name"]

            if name in args.skip:
                logger.info("skipping network: %s", name)
            else:
                network_format = source_format(network)
                network_model = network.get("model", None)
                network_type = network["type"]
                source = network["source"]

                try:
                    # NOTE(review): the branches below are not chained with
                    # elif, so a "control" network also falls into the final
                    # else and is fetched a second time (overwriting `dest`)
                    # — confirm this re-fetch is intentional
                    if network_type == "control":
                        dest, hf = fetch_model(
                            conversion,
                            name,
                            source,
                            format=network_format,
                        )

                        convert_diffusion_control(
                            conversion,
                            network,
                            dest,
                        )
                    if network_type == "inversion" and network_model == "concept":
                        # concept inversions are single .bin files on the Hub
                        dest, hf = fetch_model(
                            conversion,
                            name,
                            source,
                            dest=path.join(conversion.model_path, network_type),
                            format=network_format,
                            hf_hub_fetch=True,
                            hf_hub_filename="learned_embeds.bin",
                        )
                    else:
                        dest, hf = fetch_model(
                            conversion,
                            name,
                            source,
                            dest=path.join(conversion.model_path, network_type),
                            format=network_format,
                        )

                    logger.info("finished downloading network: %s -> %s", source, dest)
                except Exception:
                    logger.exception("error fetching network %s", name)
                    model_errors.append(name)

    # diffusion pipelines, with optional inversion/LoRA blending afterwards
    if args.diffusion and "diffusion" in models:
        for model in models.get("diffusion"):
            # normalize tuple entries into dicts
            model = tuple_to_diffusion(model)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping model: %s", name)
            else:
                model_format = source_format(model)

                try:
                    source, hf = fetch_model(
                        conversion, name, model["source"], format=model_format
                    )

                    # `converted` is True only when a fresh conversion happened
                    converted, dest = convert_diffusion_diffusers(
                        conversion,
                        model,
                        source,
                        model_format,
                        hf=hf,
                    )

                    # make sure blending only happens once, not every run
                    if converted:
                        # keep track of which models have been blended
                        blend_models = {}

                        inversion_dest = path.join(conversion.model_path, "inversion")
                        lora_dest = path.join(conversion.model_path, "lora")

                        for inversion in model.get("inversions", []):
                            # lazily load the encoder/tokenizer on first use
                            if "text_encoder" not in blend_models:
                                blend_models["text_encoder"] = load_model(
                                    path.join(
                                        dest,
                                        "text_encoder",
                                        ONNX_MODEL,
                                    )
                                )

                            if "tokenizer" not in blend_models:
                                blend_models[
                                    "tokenizer"
                                ] = CLIPTokenizer.from_pretrained(
                                    dest,
                                    subfolder="tokenizer",
                                )

                            inversion_name = inversion["name"]
                            inversion_source = inversion["source"]
                            inversion_format = inversion.get("format", None)
                            inversion_source, hf = fetch_model(
                                conversion,
                                inversion_name,
                                inversion_source,
                                dest=inversion_dest,
                            )
                            inversion_token = inversion.get("token", inversion_name)
                            inversion_weight = inversion.get("weight", 1.0)

                            blend_textual_inversions(
                                conversion,
                                blend_models["text_encoder"],
                                blend_models["tokenizer"],
                                [
                                    (
                                        inversion_source,
                                        inversion_weight,
                                        inversion_token,
                                        inversion_format,
                                    )
                                ],
                            )

                        for lora in model.get("loras", []):
                            if "text_encoder" not in blend_models:
                                blend_models["text_encoder"] = load_model(
                                    path.join(
                                        dest,
                                        "text_encoder",
                                        ONNX_MODEL,
                                    )
                                )

                            if "unet" not in blend_models:
                                blend_models["unet"] = load_model(
                                    path.join(dest, "unet", ONNX_MODEL)
                                )

                            # load models if not loaded yet
                            lora_name = lora["name"]
                            lora_source = lora["source"]
                            lora_source, hf = fetch_model(
                                conversion,
                                f"{name}-lora-{lora_name}",
                                lora_source,
                                dest=lora_dest,
                            )
                            lora_weight = lora.get("weight", 1.0)

                            # blend LoRA weights into both submodels
                            blend_loras(
                                conversion,
                                blend_models["text_encoder"],
                                [(lora_source, lora_weight)],
                                "text_encoder",
                            )

                            blend_loras(
                                conversion,
                                blend_models["unet"],
                                [(lora_source, lora_weight)],
                                "unet",
                            )

                        # persist blended results back over the converted pipeline
                        if "tokenizer" in blend_models:
                            dest_path = path.join(dest, "tokenizer")
                            logger.debug("saving blended tokenizer to %s", dest_path)
                            blend_models["tokenizer"].save_pretrained(dest_path)

                        # NOTE(review): this loop rebinds the outer `name`
                        # (the model name), so an exception raised after it
                        # would be logged under "text_encoder"/"unet"
                        for name in ["text_encoder", "unet"]:
                            if name in blend_models:
                                dest_path = path.join(dest, name, ONNX_MODEL)
                                logger.debug(
                                    "saving blended %s model to %s", name, dest_path
                                )
                                save_model(
                                    blend_models[name],
                                    dest_path,
                                    save_as_external_data=True,
                                    all_tensors_to_one_file=True,
                                    location=ONNX_WEIGHTS,
                                )

                except Exception:
                    logger.exception(
                        "error converting diffusion model %s",
                        name,
                    )
                    model_errors.append(name)

    # upscaling models, dispatched by their "model" key
    if args.upscaling and "upscaling" in models:
        for model in models.get("upscaling"):
            # normalize tuple entries into dicts
            model = tuple_to_upscaling(model)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping model: %s", name)
            else:
                model_format = source_format(model)

                try:
                    source, hf = fetch_model(
                        conversion, name, model["source"], format=model_format
                    )
                    model_type = model.get("model", "resrgan")
                    if model_type == "bsrgan":
                        convert_upscaling_bsrgan(conversion, model, source)
                    elif model_type == "resrgan":
                        convert_upscale_resrgan(conversion, model, source)
                    elif model_type == "swinir":
                        convert_upscaling_swinir(conversion, model, source)
                    else:
                        logger.error(
                            "unknown upscaling model type %s for %s", model_type, name
                        )
                        model_errors.append(name)
                except Exception:
                    logger.exception(
                        "error converting upscaling model %s",
                        name,
                    )
                    model_errors.append(name)

    # correction models, dispatched by their "model" key
    if args.correction and "correction" in models:
        for model in models.get("correction"):
            # normalize tuple entries into dicts
            model = tuple_to_correction(model)
            name = model.get("name")

            if name in args.skip:
                logger.info("skipping model: %s", name)
            else:
                model_format = source_format(model)
                try:
                    source, hf = fetch_model(
                        conversion, name, model["source"], format=model_format
                    )
                    model_type = model.get("model", "gfpgan")
                    if model_type == "gfpgan":
                        convert_correction_gfpgan(conversion, model, source)
                    else:
                        logger.error(
                            "unknown correction model type %s for %s", model_type, name
                        )
                        model_errors.append(name)
                except Exception:
                    logger.exception(
                        "error converting correction model %s",
                        name,
                    )
                    model_errors.append(name)

    # summarize all failures in one place
    if len(model_errors) > 0:
        logger.error("error while converting models: %s", model_errors)
|
2023-02-09 04:35:54 +00:00
|
|
|
|
2023-02-18 04:49:13 +00:00
|
|
|
|
2023-05-09 02:59:27 +00:00
|
|
|
def main(args=None) -> int:
    """CLI entry point: parse arguments, then convert base and extra models.

    Args:
        args: optional argument list for parse_args; None reads sys.argv.

    Returns:
        0 on completion (per-model failures are logged, not fatal).
    """
    parser = ArgumentParser(
        prog="onnx-web model converter", description="convert checkpoint models to ONNX"
    )

    # model groups
    parser.add_argument("--networks", action="store_true", default=True)
    parser.add_argument("--sources", action="store_true", default=True)
    parser.add_argument("--correction", action="store_true", default=False)
    parser.add_argument("--diffusion", action="store_true", default=False)
    parser.add_argument("--upscaling", action="store_true", default=False)

    # extra models
    parser.add_argument("--extras", nargs="*", type=str, default=[])
    parser.add_argument("--prune", nargs="*", type=str, default=[])
    parser.add_argument("--skip", nargs="*", type=str, default=[])

    # export options
    parser.add_argument(
        "--half",
        action="store_true",
        default=False,
        help="Export models for half precision, smaller and faster on most GPUs.",
    )
    parser.add_argument(
        "--opset",
        default=DEFAULT_OPSET,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument(
        "--token",
        type=str,
        help="HuggingFace token with read permissions for downloading models.",
    )

    args = parser.parse_args(args=args)
    logger.info("CLI arguments: %s", args)

    # build the conversion context from the environment, then apply CLI overrides
    server = ConversionContext.from_environ()
    server.half = args.half or "onnx-fp16" in server.optimizations
    server.opset = args.opset
    server.token = args.token
    logger.info(
        "converting models in %s using %s", server.model_path, server.training_device
    )

    if not path.exists(server.model_path):
        # fix: log message previously read "does not existing"
        logger.info("model path does not exist, creating: %s", server.model_path)
        makedirs(server.model_path)

    logger.info("converting base models")
    convert_models(server, args, base_models)

    # merge extras from the environment and the CLI, deduplicated and sorted
    extras = []
    extras.extend(server.extra_models)
    extras.extend(args.extras)
    extras = list(set(extras))
    extras.sort()
    logger.debug("loading extra files: %s", extras)

    extra_schema = load_config("./schemas/extras.yaml")

    for file in extras:
        if file is not None and file != "":
            logger.info("loading extra models from %s", file)
            try:
                data = load_config(file)
                logger.debug("validating extras file %s", data)
                try:
                    # only convert extras that pass schema validation
                    validate(data, extra_schema)
                    logger.info("converting extra models")
                    convert_models(server, args, data)
                except ValidationError:
                    logger.exception("invalid data in extras file")
            except Exception:
                logger.exception("error converting extra models")

    return 0
|
|
|
|
|
|
|
|
|
|
|
|
# script entry point: exit code comes from main() (0 on completion)
if __name__ == "__main__":
    exit(main())
|