feat(api): convert from SD checkpoints (#117)
This commit is contained in:
parent
cd4a0f10b0
commit
4f71348f98
|
@ -0,0 +1,177 @@
|
||||||
|
from .correction_gfpgan import convert_correction_gfpgan
|
||||||
|
from .diffusion_original import convert_diffusion_original
|
||||||
|
from .diffusion_stable import convert_diffusion_stable
|
||||||
|
from .upscale_resrgan import convert_upscale_resrgan
|
||||||
|
from .utils import ConversionContext
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from json import loads
|
||||||
|
from logging import getLogger
|
||||||
|
from os import environ, makedirs, path
|
||||||
|
from sys import exit
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
||||||
|
)
|
||||||
|
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
||||||
|
)
|
||||||
|
|
||||||
|
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# recommended models
|
||||||
|
base_models: Models = {
|
||||||
|
"diffusion": [
|
||||||
|
# v1.x
|
||||||
|
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
|
||||||
|
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
|
||||||
|
# v2.x
|
||||||
|
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
|
||||||
|
(
|
||||||
|
"stable-diffusion-onnx-v2-inpainting",
|
||||||
|
"stabilityai/stable-diffusion-2-inpainting",
|
||||||
|
),
|
||||||
|
# TODO: should have its own converter
|
||||||
|
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
|
||||||
|
# TODO: testing safetensors
|
||||||
|
("diffusion-stably-diffused-onnx-v2-6", "../models/tensors/stablydiffuseds_26.safetensors"),
|
||||||
|
("diffusion-unstable-ink-dream-onnx-v6", "../models/tensors/unstableinkdream_v6.safetensors"),
|
||||||
|
],
|
||||||
|
"correction": [
|
||||||
|
(
|
||||||
|
"correction-gfpgan-v1-3",
|
||||||
|
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"correction-codeformer",
|
||||||
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"upscaling": [
|
||||||
|
(
|
||||||
|
"upscaling-real-esrgan-x2-plus",
|
||||||
|
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"upscaling-real-esrgan-x4-plus",
|
||||||
|
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"upscaling-real-esrgan-x4-v3",
|
||||||
|
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||||
|
4,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
||||||
|
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def load_models(args, ctx: ConversionContext, models: Models):
|
||||||
|
if args.diffusion:
|
||||||
|
for source in models.get("diffusion"):
|
||||||
|
name, file = source
|
||||||
|
if name in args.skip:
|
||||||
|
logger.info("Skipping model: %s", source[0])
|
||||||
|
else:
|
||||||
|
if file.endswith(".safetensors") or file.endswith(".ckpt"):
|
||||||
|
convert_diffusion_original(ctx, *source, args.opset, args.half)
|
||||||
|
else:
|
||||||
|
# TODO: make this a parameter in the JSON/dict
|
||||||
|
single_vae = "upscaling" in source[0]
|
||||||
|
convert_diffusion_stable(
|
||||||
|
ctx, *source, args.opset, args.half, args.token, single_vae=single_vae
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.upscaling:
|
||||||
|
for source in models.get("upscaling"):
|
||||||
|
if source[0] in args.skip:
|
||||||
|
logger.info("Skipping model: %s", source[0])
|
||||||
|
else:
|
||||||
|
convert_upscale_resrgan(ctx, *source, args.opset)
|
||||||
|
|
||||||
|
if args.correction:
|
||||||
|
for source in models.get("correction"):
|
||||||
|
if source[0] in args.skip:
|
||||||
|
logger.info("Skipping model: %s", source[0])
|
||||||
|
else:
|
||||||
|
convert_correction_gfpgan(ctx, *source, args.opset)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = ArgumentParser(
|
||||||
|
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
||||||
|
)
|
||||||
|
|
||||||
|
# model groups
|
||||||
|
parser.add_argument("--correction", action="store_true", default=False)
|
||||||
|
parser.add_argument("--diffusion", action="store_true", default=False)
|
||||||
|
parser.add_argument("--upscaling", action="store_true", default=False)
|
||||||
|
|
||||||
|
# extra models
|
||||||
|
parser.add_argument("--extras", nargs="*", type=str, default=[])
|
||||||
|
parser.add_argument("--skip", nargs="*", type=str, default=[])
|
||||||
|
|
||||||
|
# export options
|
||||||
|
parser.add_argument(
|
||||||
|
"--half",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Export models for half precision, faster on some Nvidia cards.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--opset",
|
||||||
|
default=14,
|
||||||
|
type=int,
|
||||||
|
help="The version of the ONNX operator set to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token",
|
||||||
|
type=str,
|
||||||
|
help="HuggingFace token with read permissions for downloading models.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
logger.info("CLI arguments: %s", args)
|
||||||
|
|
||||||
|
ctx = ConversionContext(model_path, training_device)
|
||||||
|
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
||||||
|
|
||||||
|
if not path.exists(model_path):
|
||||||
|
logger.info("Model path does not existing, creating: %s", model_path)
|
||||||
|
makedirs(model_path)
|
||||||
|
|
||||||
|
logger.info("Converting base models.")
|
||||||
|
load_models(args, ctx, base_models)
|
||||||
|
|
||||||
|
for file in args.extras:
|
||||||
|
if file is not None and file != "":
|
||||||
|
logger.info("Loading extra models from %s", file)
|
||||||
|
try:
|
||||||
|
with open(file, "r") as f:
|
||||||
|
data = loads(f.read())
|
||||||
|
logger.info("Converting extra models.")
|
||||||
|
load_models(args, ctx, data)
|
||||||
|
except Exception as err:
|
||||||
|
logger.error("Error converting extra models: %s", err)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit(main())
|
|
@ -0,0 +1,68 @@
|
||||||
|
import torch
|
||||||
|
from shutil import copyfile
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
from torch.onnx import export
|
||||||
|
from os import path
|
||||||
|
from logging import getLogger
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from .utils import ConversionContext
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
|
||||||
|
dest_path = path.join(ctx.model_path, name + ".pth")
|
||||||
|
dest_onnx = path.join(ctx.model_path, name + ".onnx")
|
||||||
|
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
|
||||||
|
|
||||||
|
if path.isfile(dest_onnx):
|
||||||
|
logger.info("ONNX model already exists, skipping.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not path.isfile(dest_path):
|
||||||
|
logger.info("PTH model not found, downloading...")
|
||||||
|
download_path = load_file_from_url(
|
||||||
|
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||||
|
)
|
||||||
|
copyfile(download_path, dest_path)
|
||||||
|
|
||||||
|
logger.info("loading and training model")
|
||||||
|
model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_model = torch.load(dest_path, map_location=ctx.map_location)
|
||||||
|
# TODO: make sure strict=False is safe here
|
||||||
|
if "params_ema" in torch_model:
|
||||||
|
model.load_state_dict(torch_model["params_ema"], strict=False)
|
||||||
|
else:
|
||||||
|
model.load_state_dict(torch_model["params"], strict=False)
|
||||||
|
|
||||||
|
model.to(ctx.training_device).train(False)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
|
||||||
|
input_names = ["data"]
|
||||||
|
output_names = ["output"]
|
||||||
|
dynamic_axes = {
|
||||||
|
"data": {2: "width", 3: "height"},
|
||||||
|
"output": {2: "width", 3: "height"},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||||
|
export(
|
||||||
|
model,
|
||||||
|
rng,
|
||||||
|
dest_onnx,
|
||||||
|
input_names=input_names,
|
||||||
|
output_names=output_names,
|
||||||
|
dynamic_axes=dynamic_axes,
|
||||||
|
opset_version=opset,
|
||||||
|
export_params=True,
|
||||||
|
)
|
||||||
|
logger.info("GFPGAN exported to ONNX successfully.")
|
File diff suppressed because it is too large
Load Diff
|
@ -1,208 +1,23 @@
|
||||||
import warnings
|
|
||||||
from argparse import ArgumentParser
|
|
||||||
from json import loads
|
|
||||||
from logging import getLogger
|
|
||||||
from os import environ, makedirs, mkdir, path
|
|
||||||
from pathlib import Path
|
|
||||||
from shutil import copyfile, rmtree
|
|
||||||
from sys import exit
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
OnnxRuntimeModel,
|
OnnxRuntimeModel,
|
||||||
OnnxStableDiffusionPipeline,
|
OnnxStableDiffusionPipeline,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
|
||||||
)
|
)
|
||||||
from onnx import load, save_model
|
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
from shutil import rmtree
|
||||||
warnings.filterwarnings(
|
import torch
|
||||||
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
from os import path, mkdir
|
||||||
)
|
from pathlib import Path
|
||||||
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
|
from onnx import load, save_model
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore",
|
|
||||||
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
|
||||||
)
|
|
||||||
|
|
||||||
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
from .utils import ConversionContext
|
||||||
|
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# recommended models
|
|
||||||
base_models: Models = {
|
|
||||||
"diffusion": [
|
|
||||||
# v1.x
|
|
||||||
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
|
|
||||||
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
|
|
||||||
# v2.x
|
|
||||||
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
|
|
||||||
(
|
|
||||||
"stable-diffusion-onnx-v2-inpainting",
|
|
||||||
"stabilityai/stable-diffusion-2-inpainting",
|
|
||||||
),
|
|
||||||
# TODO: should have its own converter
|
|
||||||
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
|
|
||||||
],
|
|
||||||
"correction": [
|
|
||||||
(
|
|
||||||
"correction-gfpgan-v1-3",
|
|
||||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
|
|
||||||
4,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"correction-codeformer",
|
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
|
|
||||||
1,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
"upscaling": [
|
|
||||||
(
|
|
||||||
"upscaling-real-esrgan-x2-plus",
|
|
||||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
|
||||||
2,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"upscaling-real-esrgan-x4-plus",
|
|
||||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
|
||||||
4,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"upscaling-real-esrgan-x4-v3",
|
|
||||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
|
||||||
4,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
|
||||||
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
map_location = torch.device(training_device)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
|
||||||
dest_path = path.join(model_path, name + ".pth")
|
|
||||||
dest_onnx = path.join(model_path, name + ".onnx")
|
|
||||||
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
|
|
||||||
|
|
||||||
if path.isfile(dest_onnx):
|
|
||||||
logger.info("ONNX model already exists, skipping.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not path.isfile(dest_path):
|
|
||||||
logger.info("PTH model not found, downloading...")
|
|
||||||
download_path = load_file_from_url(
|
|
||||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
|
||||||
)
|
|
||||||
copyfile(download_path, dest_path)
|
|
||||||
|
|
||||||
logger.info("loading and training model")
|
|
||||||
model = RRDBNet(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_block=23,
|
|
||||||
num_grow_ch=32,
|
|
||||||
scale=scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_model = torch.load(dest_path, map_location=map_location)
|
|
||||||
if "params_ema" in torch_model:
|
|
||||||
model.load_state_dict(torch_model["params_ema"])
|
|
||||||
else:
|
|
||||||
model.load_state_dict(torch_model["params"], strict=False)
|
|
||||||
|
|
||||||
model.to(training_device).train(False)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
|
||||||
input_names = ["data"]
|
|
||||||
output_names = ["output"]
|
|
||||||
dynamic_axes = {
|
|
||||||
"data": {2: "width", 3: "height"},
|
|
||||||
"output": {2: "width", 3: "height"},
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
|
||||||
export(
|
|
||||||
model,
|
|
||||||
rng,
|
|
||||||
dest_onnx,
|
|
||||||
input_names=input_names,
|
|
||||||
output_names=output_names,
|
|
||||||
dynamic_axes=dynamic_axes,
|
|
||||||
opset_version=opset,
|
|
||||||
export_params=True,
|
|
||||||
)
|
|
||||||
logger.info("Real ESRGAN exported to ONNX successfully.")
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert_gfpgan(name: str, url: str, scale: int, opset: int):
|
|
||||||
dest_path = path.join(model_path, name + ".pth")
|
|
||||||
dest_onnx = path.join(model_path, name + ".onnx")
|
|
||||||
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
|
|
||||||
|
|
||||||
if path.isfile(dest_onnx):
|
|
||||||
logger.info("ONNX model already exists, skipping.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not path.isfile(dest_path):
|
|
||||||
logger.info("PTH model not found, downloading...")
|
|
||||||
download_path = load_file_from_url(
|
|
||||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
|
||||||
)
|
|
||||||
copyfile(download_path, dest_path)
|
|
||||||
|
|
||||||
logger.info("loading and training model")
|
|
||||||
model = RRDBNet(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_block=23,
|
|
||||||
num_grow_ch=32,
|
|
||||||
scale=scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_model = torch.load(dest_path, map_location=map_location)
|
|
||||||
# TODO: make sure strict=False is safe here
|
|
||||||
if "params_ema" in torch_model:
|
|
||||||
model.load_state_dict(torch_model["params_ema"], strict=False)
|
|
||||||
else:
|
|
||||||
model.load_state_dict(torch_model["params"], strict=False)
|
|
||||||
|
|
||||||
model.to(training_device).train(False)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
|
||||||
input_names = ["data"]
|
|
||||||
output_names = ["output"]
|
|
||||||
dynamic_axes = {
|
|
||||||
"data": {2: "width", 3: "height"},
|
|
||||||
"output": {2: "width", 3: "height"},
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
|
||||||
export(
|
|
||||||
model,
|
|
||||||
rng,
|
|
||||||
dest_onnx,
|
|
||||||
input_names=input_names,
|
|
||||||
output_names=output_names,
|
|
||||||
dynamic_axes=dynamic_axes,
|
|
||||||
opset_version=opset,
|
|
||||||
export_params=True,
|
|
||||||
)
|
|
||||||
logger.info("GFPGAN exported to ONNX successfully.")
|
|
||||||
|
|
||||||
|
|
||||||
def onnx_export(
|
def onnx_export(
|
||||||
model,
|
model,
|
||||||
model_args: tuple,
|
model_args: tuple,
|
||||||
|
@ -231,33 +46,39 @@ def onnx_export(
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffuser(
|
def convert_diffusion_stable(
|
||||||
name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False
|
ctx: ConversionContext,
|
||||||
|
name: str,
|
||||||
|
url: str,
|
||||||
|
opset: int,
|
||||||
|
half: bool,
|
||||||
|
token: str,
|
||||||
|
single_vae: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
dtype = torch.float16 if half else torch.float32
|
dtype = torch.float16 if half else torch.float32
|
||||||
dest_path = path.join(model_path, name)
|
dest_path = path.join(ctx.model_path, name)
|
||||||
|
|
||||||
# diffusers go into a directory rather than .onnx file
|
# diffusers go into a directory rather than .onnx file
|
||||||
logger.info("converting Diffusers model: %s -> %s/", name, dest_path)
|
logger.info("converting Stable Diffusion model %s: %s -> %s/", name, url, dest_path)
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
logger.info("converting model with single VAE")
|
logger.info("converting model with single VAE")
|
||||||
|
|
||||||
if path.isdir(dest_path):
|
if path.exists(dest_path):
|
||||||
logger.info("ONNX model already exists, skipping.")
|
logger.info("ONNX model already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if half and training_device != "cuda":
|
if half and ctx.training_device != "cuda":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Half precision model export is only supported on GPUs with CUDA"
|
"Half precision model export is only supported on GPUs with CUDA"
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
url, torch_dtype=dtype, use_auth_token=token
|
url, torch_dtype=dtype, use_auth_token=token
|
||||||
).to(training_device)
|
).to(ctx.training_device)
|
||||||
output_path = Path(dest_path)
|
output_path = Path(dest_path)
|
||||||
|
|
||||||
# TEXT ENCODER
|
# TEXT ENCODER
|
||||||
|
@ -273,7 +94,7 @@ def convert_diffuser(
|
||||||
onnx_export(
|
onnx_export(
|
||||||
pipeline.text_encoder,
|
pipeline.text_encoder,
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
model_args=(text_input.input_ids.to(device=training_device, dtype=torch.int32)),
|
model_args=(text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)),
|
||||||
output_path=output_path / "text_encoder" / "model.onnx",
|
output_path=output_path / "text_encoder" / "model.onnx",
|
||||||
ordered_input_names=["input_ids"],
|
ordered_input_names=["input_ids"],
|
||||||
output_names=["last_hidden_state", "pooler_output"],
|
output_names=["last_hidden_state", "pooler_output"],
|
||||||
|
@ -302,11 +123,11 @@ def convert_diffuser(
|
||||||
pipeline.unet,
|
pipeline.unet,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||||
device=training_device, dtype=dtype
|
device=ctx.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
torch.randn(2).to(device=training_device, dtype=dtype),
|
torch.randn(2).to(device=ctx.training_device, dtype=dtype),
|
||||||
torch.randn(2, num_tokens, text_hidden_size).to(
|
torch.randn(2, num_tokens, text_hidden_size).to(
|
||||||
device=training_device, dtype=dtype
|
device=ctx.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
unet_scale,
|
unet_scale,
|
||||||
),
|
),
|
||||||
|
@ -353,7 +174,7 @@ def convert_diffuser(
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(
|
torch.randn(
|
||||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||||
).to(device=training_device, dtype=dtype),
|
).to(device=ctx.training_device, dtype=dtype),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=output_path / "vae" / "model.onnx",
|
output_path=output_path / "vae" / "model.onnx",
|
||||||
|
@ -377,7 +198,7 @@ def convert_diffuser(
|
||||||
vae_encoder,
|
vae_encoder,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||||
device=training_device, dtype=dtype
|
device=ctx.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
|
@ -401,7 +222,7 @@ def convert_diffuser(
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(
|
torch.randn(
|
||||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||||
).to(device=training_device, dtype=dtype),
|
).to(device=ctx.training_device, dtype=dtype),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||||
|
@ -429,9 +250,9 @@ def convert_diffuser(
|
||||||
clip_num_channels,
|
clip_num_channels,
|
||||||
clip_image_size,
|
clip_image_size,
|
||||||
clip_image_size,
|
clip_image_size,
|
||||||
).to(device=training_device, dtype=dtype),
|
).to(device=ctx.training_device, dtype=dtype),
|
||||||
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
|
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
|
||||||
device=training_device, dtype=dtype
|
device=ctx.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
output_path=output_path / "safety_checker" / "model.onnx",
|
output_path=output_path / "safety_checker" / "model.onnx",
|
||||||
|
@ -453,7 +274,7 @@ def convert_diffuser(
|
||||||
feature_extractor = None
|
feature_extractor = None
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
onnx_pipeline = StableDiffusionUpscalePipeline(
|
onnx_pipeline = OnnxStableDiffusionUpscalePipeline(
|
||||||
vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
|
vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
|
||||||
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
|
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
|
||||||
tokenizer=pipeline.tokenizer,
|
tokenizer=pipeline.tokenizer,
|
||||||
|
@ -483,7 +304,7 @@ def convert_diffuser(
|
||||||
del onnx_pipeline
|
del onnx_pipeline
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
_ = StableDiffusionUpscalePipeline.from_pretrained(
|
_ = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||||
output_path, provider="CPUExecutionProvider"
|
output_path, provider="CPUExecutionProvider"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -493,89 +314,3 @@ def convert_diffuser(
|
||||||
|
|
||||||
logger.info("ONNX pipeline is loadable")
|
logger.info("ONNX pipeline is loadable")
|
||||||
|
|
||||||
|
|
||||||
def load_models(args, models: Models):
|
|
||||||
if args.diffusion:
|
|
||||||
for source in models.get("diffusion"):
|
|
||||||
if source[0] in args.skip:
|
|
||||||
logger.info("Skipping model: %s", source[0])
|
|
||||||
else:
|
|
||||||
single_vae = "upscaling" in source[0]
|
|
||||||
convert_diffuser(
|
|
||||||
*source, args.opset, args.half, args.token, single_vae=single_vae
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.upscaling:
|
|
||||||
for source in models.get("upscaling"):
|
|
||||||
if source[0] in args.skip:
|
|
||||||
logger.info("Skipping model: %s", source[0])
|
|
||||||
else:
|
|
||||||
convert_real_esrgan(*source, args.opset)
|
|
||||||
|
|
||||||
if args.correction:
|
|
||||||
for source in models.get("correction"):
|
|
||||||
if source[0] in args.skip:
|
|
||||||
logger.info("Skipping model: %s", source[0])
|
|
||||||
else:
|
|
||||||
convert_gfpgan(*source, args.opset)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
|
||||||
parser = ArgumentParser(
|
|
||||||
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
|
||||||
)
|
|
||||||
|
|
||||||
# model groups
|
|
||||||
parser.add_argument("--correction", action="store_true", default=False)
|
|
||||||
parser.add_argument("--diffusion", action="store_true", default=False)
|
|
||||||
parser.add_argument("--upscaling", action="store_true", default=False)
|
|
||||||
|
|
||||||
# extra models
|
|
||||||
parser.add_argument("--extras", nargs="*", type=str, default=[])
|
|
||||||
parser.add_argument("--skip", nargs="*", type=str, default=[])
|
|
||||||
|
|
||||||
# export options
|
|
||||||
parser.add_argument(
|
|
||||||
"--half",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Export models for half precision, faster on some Nvidia cards.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--opset",
|
|
||||||
default=14,
|
|
||||||
type=int,
|
|
||||||
help="The version of the ONNX operator set to use.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--token",
|
|
||||||
type=str,
|
|
||||||
help="HuggingFace token with read permissions for downloading models.",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
logger.info("CLI arguments: %s", args)
|
|
||||||
|
|
||||||
if not path.exists(model_path):
|
|
||||||
logger.info("Model path does not existing, creating: %s", model_path)
|
|
||||||
makedirs(model_path)
|
|
||||||
|
|
||||||
logger.info("Converting base models.")
|
|
||||||
load_models(args, base_models)
|
|
||||||
|
|
||||||
for file in args.extras:
|
|
||||||
if file is not None and file != "":
|
|
||||||
logger.info("Loading extra models from %s", file)
|
|
||||||
try:
|
|
||||||
with open(file, "r") as f:
|
|
||||||
data = loads(f.read())
|
|
||||||
logger.info("Converting extra models.")
|
|
||||||
load_models(args, data)
|
|
||||||
except Exception as err:
|
|
||||||
logger.error("Error converting extra models: %s", err)
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
exit(main())
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
import torch
|
||||||
|
from shutil import copyfile
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
from torch.onnx import export
|
||||||
|
from os import path
|
||||||
|
from logging import getLogger
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from .utils import ConversionContext
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
|
||||||
|
dest_path = path.join(ctx.model_path, name + ".pth")
|
||||||
|
dest_onnx = path.join(ctx.model_path, name + ".onnx")
|
||||||
|
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
|
||||||
|
|
||||||
|
if path.isfile(dest_onnx):
|
||||||
|
logger.info("ONNX model already exists, skipping.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not path.isfile(dest_path):
|
||||||
|
logger.info("PTH model not found, downloading...")
|
||||||
|
download_path = load_file_from_url(
|
||||||
|
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||||
|
)
|
||||||
|
copyfile(download_path, dest_path)
|
||||||
|
|
||||||
|
logger.info("loading and training model")
|
||||||
|
model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_model = torch.load(dest_path, map_location=ctx.map_location)
|
||||||
|
if "params_ema" in torch_model:
|
||||||
|
model.load_state_dict(torch_model["params_ema"])
|
||||||
|
else:
|
||||||
|
model.load_state_dict(torch_model["params"], strict=False)
|
||||||
|
|
||||||
|
model.to(ctx.training_device).train(False)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
|
||||||
|
input_names = ["data"]
|
||||||
|
output_names = ["output"]
|
||||||
|
dynamic_axes = {
|
||||||
|
"data": {2: "width", 3: "height"},
|
||||||
|
"output": {2: "width", 3: "height"},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||||
|
export(
|
||||||
|
model,
|
||||||
|
rng,
|
||||||
|
dest_onnx,
|
||||||
|
input_names=input_names,
|
||||||
|
output_names=output_names,
|
||||||
|
dynamic_axes=dynamic_axes,
|
||||||
|
opset_version=opset,
|
||||||
|
export_params=True,
|
||||||
|
)
|
||||||
|
logger.info("Real ESRGAN exported to ONNX successfully.")
|
|
@ -0,0 +1,7 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class ConversionContext:
|
||||||
|
def __init__(self, model_path: str, device: str) -> None:
|
||||||
|
self.model_path = model_path
|
||||||
|
self.training_device = device
|
||||||
|
self.map_location = torch.device(device)
|
|
@ -96,7 +96,7 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
if device is not None and hasattr(pipe, "to"):
|
if device is not None and hasattr(pipe, "to"):
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device.torch_device())
|
||||||
|
|
||||||
last_pipeline_instance = pipe
|
last_pipeline_instance = pipe
|
||||||
last_pipeline_options = options
|
last_pipeline_options = options
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "images"
|
||||||
|
cond_stage_key: "input_ids"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: true # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
Loading…
Reference in New Issue