feat(api): convert from SD checkpoints (#117)
This commit is contained in:
parent
cd4a0f10b0
commit
4f71348f98
|
@ -0,0 +1,177 @@
|
|||
from .correction_gfpgan import convert_correction_gfpgan
|
||||
from .diffusion_original import convert_diffusion_original
|
||||
from .diffusion_stable import convert_diffusion_stable
|
||||
from .upscale_resrgan import convert_upscale_resrgan
|
||||
from .utils import ConversionContext
|
||||
|
||||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
from json import loads
|
||||
from logging import getLogger
|
||||
from os import environ, makedirs, path
|
||||
from sys import exit
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
||||
warnings.filterwarnings(
|
||||
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
||||
)
|
||||
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
||||
)
|
||||
|
||||
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
# recommended models
|
||||
base_models: Models = {
|
||||
"diffusion": [
|
||||
# v1.x
|
||||
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
|
||||
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
|
||||
# v2.x
|
||||
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
|
||||
(
|
||||
"stable-diffusion-onnx-v2-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
),
|
||||
# TODO: should have its own converter
|
||||
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
|
||||
# TODO: testing safetensors
|
||||
("diffusion-stably-diffused-onnx-v2-6", "../models/tensors/stablydiffuseds_26.safetensors"),
|
||||
("diffusion-unstable-ink-dream-onnx-v6", "../models/tensors/unstableinkdream_v6.safetensors"),
|
||||
],
|
||||
"correction": [
|
||||
(
|
||||
"correction-gfpgan-v1-3",
|
||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
|
||||
4,
|
||||
),
|
||||
(
|
||||
"correction-codeformer",
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
|
||||
1,
|
||||
),
|
||||
],
|
||||
"upscaling": [
|
||||
(
|
||||
"upscaling-real-esrgan-x2-plus",
|
||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
2,
|
||||
),
|
||||
(
|
||||
"upscaling-real-esrgan-x4-plus",
|
||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
4,
|
||||
),
|
||||
(
|
||||
"upscaling-real-esrgan-x4-v3",
|
||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
4,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
||||
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def load_models(args, ctx: ConversionContext, models: Models):
|
||||
if args.diffusion:
|
||||
for source in models.get("diffusion"):
|
||||
name, file = source
|
||||
if name in args.skip:
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
else:
|
||||
if file.endswith(".safetensors") or file.endswith(".ckpt"):
|
||||
convert_diffusion_original(ctx, *source, args.opset, args.half)
|
||||
else:
|
||||
# TODO: make this a parameter in the JSON/dict
|
||||
single_vae = "upscaling" in source[0]
|
||||
convert_diffusion_stable(
|
||||
ctx, *source, args.opset, args.half, args.token, single_vae=single_vae
|
||||
)
|
||||
|
||||
if args.upscaling:
|
||||
for source in models.get("upscaling"):
|
||||
if source[0] in args.skip:
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
else:
|
||||
convert_upscale_resrgan(ctx, *source, args.opset)
|
||||
|
||||
if args.correction:
|
||||
for source in models.get("correction"):
|
||||
if source[0] in args.skip:
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
else:
|
||||
convert_correction_gfpgan(ctx, *source, args.opset)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = ArgumentParser(
|
||||
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
||||
)
|
||||
|
||||
# model groups
|
||||
parser.add_argument("--correction", action="store_true", default=False)
|
||||
parser.add_argument("--diffusion", action="store_true", default=False)
|
||||
parser.add_argument("--upscaling", action="store_true", default=False)
|
||||
|
||||
# extra models
|
||||
parser.add_argument("--extras", nargs="*", type=str, default=[])
|
||||
parser.add_argument("--skip", nargs="*", type=str, default=[])
|
||||
|
||||
# export options
|
||||
parser.add_argument(
|
||||
"--half",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Export models for half precision, faster on some Nvidia cards.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--opset",
|
||||
default=14,
|
||||
type=int,
|
||||
help="The version of the ONNX operator set to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
type=str,
|
||||
help="HuggingFace token with read permissions for downloading models.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info("CLI arguments: %s", args)
|
||||
|
||||
ctx = ConversionContext(model_path, training_device)
|
||||
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
||||
|
||||
if not path.exists(model_path):
|
||||
logger.info("Model path does not existing, creating: %s", model_path)
|
||||
makedirs(model_path)
|
||||
|
||||
logger.info("Converting base models.")
|
||||
load_models(args, ctx, base_models)
|
||||
|
||||
for file in args.extras:
|
||||
if file is not None and file != "":
|
||||
logger.info("Loading extra models from %s", file)
|
||||
try:
|
||||
with open(file, "r") as f:
|
||||
data = loads(f.read())
|
||||
logger.info("Converting extra models.")
|
||||
load_models(args, ctx, data)
|
||||
except Exception as err:
|
||||
logger.error("Error converting extra models: %s", err)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
|
@ -0,0 +1,68 @@
|
|||
import torch
|
||||
from shutil import copyfile
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.onnx import export
|
||||
from os import path
|
||||
from logging import getLogger
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from .utils import ConversionContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(ctx.model_path, name + ".pth")
|
||||
dest_onnx = path.join(ctx.model_path, name + ".onnx")
|
||||
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
logger.info("PTH model not found, downloading...")
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||
)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
logger.info("loading and training model")
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=ctx.map_location)
|
||||
# TODO: make sure strict=False is safe here
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"], strict=False)
|
||||
else:
|
||||
model.load_state_dict(torch_model["params"], strict=False)
|
||||
|
||||
model.to(ctx.training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
|
||||
input_names = ["data"]
|
||||
output_names = ["output"]
|
||||
dynamic_axes = {
|
||||
"data": {2: "width", 3: "height"},
|
||||
"output": {2: "width", 3: "height"},
|
||||
}
|
||||
|
||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
dest_onnx,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=opset,
|
||||
export_params=True,
|
||||
)
|
||||
logger.info("GFPGAN exported to ONNX successfully.")
|
File diff suppressed because it is too large
Load Diff
|
@ -1,208 +1,23 @@
|
|||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
from json import loads
|
||||
from logging import getLogger
|
||||
from os import environ, makedirs, mkdir, path
|
||||
from pathlib import Path
|
||||
from shutil import copyfile, rmtree
|
||||
from sys import exit
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from diffusers import (
|
||||
OnnxRuntimeModel,
|
||||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from onnx import load, save_model
|
||||
from torch.onnx import export
|
||||
from logging import getLogger
|
||||
|
||||
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
||||
warnings.filterwarnings(
|
||||
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
||||
)
|
||||
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
||||
)
|
||||
from shutil import rmtree
|
||||
import torch
|
||||
from os import path, mkdir
|
||||
from pathlib import Path
|
||||
from onnx import load, save_model
|
||||
|
||||
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
||||
from .utils import ConversionContext
|
||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
# recommended models
|
||||
base_models: Models = {
|
||||
"diffusion": [
|
||||
# v1.x
|
||||
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
|
||||
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
|
||||
# v2.x
|
||||
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
|
||||
(
|
||||
"stable-diffusion-onnx-v2-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
),
|
||||
# TODO: should have its own converter
|
||||
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
|
||||
],
|
||||
"correction": [
|
||||
(
|
||||
"correction-gfpgan-v1-3",
|
||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
|
||||
4,
|
||||
),
|
||||
(
|
||||
"correction-codeformer",
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
|
||||
1,
|
||||
),
|
||||
],
|
||||
"upscaling": [
|
||||
(
|
||||
"upscaling-real-esrgan-x2-plus",
|
||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
2,
|
||||
),
|
||||
(
|
||||
"upscaling-real-esrgan-x4-plus",
|
||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
4,
|
||||
),
|
||||
(
|
||||
"upscaling-real-esrgan-x4-v3",
|
||||
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
4,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
||||
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
map_location = torch.device(training_device)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(model_path, name + ".pth")
|
||||
dest_onnx = path.join(model_path, name + ".onnx")
|
||||
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
logger.info("PTH model not found, downloading...")
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||
)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
logger.info("loading and training model")
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=map_location)
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"])
|
||||
else:
|
||||
model.load_state_dict(torch_model["params"], strict=False)
|
||||
|
||||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||
input_names = ["data"]
|
||||
output_names = ["output"]
|
||||
dynamic_axes = {
|
||||
"data": {2: "width", 3: "height"},
|
||||
"output": {2: "width", 3: "height"},
|
||||
}
|
||||
|
||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
dest_onnx,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=opset,
|
||||
export_params=True,
|
||||
)
|
||||
logger.info("Real ESRGAN exported to ONNX successfully.")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_gfpgan(name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(model_path, name + ".pth")
|
||||
dest_onnx = path.join(model_path, name + ".onnx")
|
||||
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
logger.info("PTH model not found, downloading...")
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||
)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
logger.info("loading and training model")
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=map_location)
|
||||
# TODO: make sure strict=False is safe here
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"], strict=False)
|
||||
else:
|
||||
model.load_state_dict(torch_model["params"], strict=False)
|
||||
|
||||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||
input_names = ["data"]
|
||||
output_names = ["output"]
|
||||
dynamic_axes = {
|
||||
"data": {2: "width", 3: "height"},
|
||||
"output": {2: "width", 3: "height"},
|
||||
}
|
||||
|
||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
dest_onnx,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=opset,
|
||||
export_params=True,
|
||||
)
|
||||
logger.info("GFPGAN exported to ONNX successfully.")
|
||||
|
||||
|
||||
def onnx_export(
|
||||
model,
|
||||
model_args: tuple,
|
||||
|
@ -231,33 +46,39 @@ def onnx_export(
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffuser(
|
||||
name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False
|
||||
def convert_diffusion_stable(
|
||||
ctx: ConversionContext,
|
||||
name: str,
|
||||
url: str,
|
||||
opset: int,
|
||||
half: bool,
|
||||
token: str,
|
||||
single_vae: bool = False,
|
||||
):
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
dtype = torch.float16 if half else torch.float32
|
||||
dest_path = path.join(model_path, name)
|
||||
dest_path = path.join(ctx.model_path, name)
|
||||
|
||||
# diffusers go into a directory rather than .onnx file
|
||||
logger.info("converting Diffusers model: %s -> %s/", name, dest_path)
|
||||
logger.info("converting Stable Diffusion model %s: %s -> %s/", name, url, dest_path)
|
||||
|
||||
if single_vae:
|
||||
logger.info("converting model with single VAE")
|
||||
|
||||
if path.isdir(dest_path):
|
||||
if path.exists(dest_path):
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if half and training_device != "cuda":
|
||||
if half and ctx.training_device != "cuda":
|
||||
raise ValueError(
|
||||
"Half precision model export is only supported on GPUs with CUDA"
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
url, torch_dtype=dtype, use_auth_token=token
|
||||
).to(training_device)
|
||||
).to(ctx.training_device)
|
||||
output_path = Path(dest_path)
|
||||
|
||||
# TEXT ENCODER
|
||||
|
@ -273,7 +94,7 @@ def convert_diffuser(
|
|||
onnx_export(
|
||||
pipeline.text_encoder,
|
||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||
model_args=(text_input.input_ids.to(device=training_device, dtype=torch.int32)),
|
||||
model_args=(text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)),
|
||||
output_path=output_path / "text_encoder" / "model.onnx",
|
||||
ordered_input_names=["input_ids"],
|
||||
output_names=["last_hidden_state", "pooler_output"],
|
||||
|
@ -302,11 +123,11 @@ def convert_diffuser(
|
|||
pipeline.unet,
|
||||
model_args=(
|
||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=training_device, dtype=dtype
|
||||
device=ctx.training_device, dtype=dtype
|
||||
),
|
||||
torch.randn(2).to(device=training_device, dtype=dtype),
|
||||
torch.randn(2).to(device=ctx.training_device, dtype=dtype),
|
||||
torch.randn(2, num_tokens, text_hidden_size).to(
|
||||
device=training_device, dtype=dtype
|
||||
device=ctx.training_device, dtype=dtype
|
||||
),
|
||||
unet_scale,
|
||||
),
|
||||
|
@ -353,7 +174,7 @@ def convert_diffuser(
|
|||
model_args=(
|
||||
torch.randn(
|
||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||
).to(device=training_device, dtype=dtype),
|
||||
).to(device=ctx.training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae" / "model.onnx",
|
||||
|
@ -377,7 +198,7 @@ def convert_diffuser(
|
|||
vae_encoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||
device=training_device, dtype=dtype
|
||||
device=ctx.training_device, dtype=dtype
|
||||
),
|
||||
False,
|
||||
),
|
||||
|
@ -401,7 +222,7 @@ def convert_diffuser(
|
|||
model_args=(
|
||||
torch.randn(
|
||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||
).to(device=training_device, dtype=dtype),
|
||||
).to(device=ctx.training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||
|
@ -429,9 +250,9 @@ def convert_diffuser(
|
|||
clip_num_channels,
|
||||
clip_image_size,
|
||||
clip_image_size,
|
||||
).to(device=training_device, dtype=dtype),
|
||||
).to(device=ctx.training_device, dtype=dtype),
|
||||
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
|
||||
device=training_device, dtype=dtype
|
||||
device=ctx.training_device, dtype=dtype
|
||||
),
|
||||
),
|
||||
output_path=output_path / "safety_checker" / "model.onnx",
|
||||
|
@ -453,7 +274,7 @@ def convert_diffuser(
|
|||
feature_extractor = None
|
||||
|
||||
if single_vae:
|
||||
onnx_pipeline = StableDiffusionUpscalePipeline(
|
||||
onnx_pipeline = OnnxStableDiffusionUpscalePipeline(
|
||||
vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
|
||||
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
|
||||
tokenizer=pipeline.tokenizer,
|
||||
|
@ -483,7 +304,7 @@ def convert_diffuser(
|
|||
del onnx_pipeline
|
||||
|
||||
if single_vae:
|
||||
_ = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
_ = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
output_path, provider="CPUExecutionProvider"
|
||||
)
|
||||
else:
|
||||
|
@ -493,89 +314,3 @@ def convert_diffuser(
|
|||
|
||||
logger.info("ONNX pipeline is loadable")
|
||||
|
||||
|
||||
def load_models(args, models: Models):
|
||||
if args.diffusion:
|
||||
for source in models.get("diffusion"):
|
||||
if source[0] in args.skip:
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
else:
|
||||
single_vae = "upscaling" in source[0]
|
||||
convert_diffuser(
|
||||
*source, args.opset, args.half, args.token, single_vae=single_vae
|
||||
)
|
||||
|
||||
if args.upscaling:
|
||||
for source in models.get("upscaling"):
|
||||
if source[0] in args.skip:
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
else:
|
||||
convert_real_esrgan(*source, args.opset)
|
||||
|
||||
if args.correction:
|
||||
for source in models.get("correction"):
|
||||
if source[0] in args.skip:
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
else:
|
||||
convert_gfpgan(*source, args.opset)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = ArgumentParser(
|
||||
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
||||
)
|
||||
|
||||
# model groups
|
||||
parser.add_argument("--correction", action="store_true", default=False)
|
||||
parser.add_argument("--diffusion", action="store_true", default=False)
|
||||
parser.add_argument("--upscaling", action="store_true", default=False)
|
||||
|
||||
# extra models
|
||||
parser.add_argument("--extras", nargs="*", type=str, default=[])
|
||||
parser.add_argument("--skip", nargs="*", type=str, default=[])
|
||||
|
||||
# export options
|
||||
parser.add_argument(
|
||||
"--half",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Export models for half precision, faster on some Nvidia cards.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--opset",
|
||||
default=14,
|
||||
type=int,
|
||||
help="The version of the ONNX operator set to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
type=str,
|
||||
help="HuggingFace token with read permissions for downloading models.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info("CLI arguments: %s", args)
|
||||
|
||||
if not path.exists(model_path):
|
||||
logger.info("Model path does not existing, creating: %s", model_path)
|
||||
makedirs(model_path)
|
||||
|
||||
logger.info("Converting base models.")
|
||||
load_models(args, base_models)
|
||||
|
||||
for file in args.extras:
|
||||
if file is not None and file != "":
|
||||
logger.info("Loading extra models from %s", file)
|
||||
try:
|
||||
with open(file, "r") as f:
|
||||
data = loads(f.read())
|
||||
logger.info("Converting extra models.")
|
||||
load_models(args, data)
|
||||
except Exception as err:
|
||||
logger.error("Error converting extra models: %s", err)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
|
@ -0,0 +1,68 @@
|
|||
import torch
|
||||
from shutil import copyfile
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.onnx import export
|
||||
from os import path
|
||||
from logging import getLogger
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from .utils import ConversionContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(ctx.model_path, name + ".pth")
|
||||
dest_onnx = path.join(ctx.model_path, name + ".onnx")
|
||||
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
logger.info("PTH model not found, downloading...")
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||
)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
logger.info("loading and training model")
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=ctx.map_location)
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"])
|
||||
else:
|
||||
model.load_state_dict(torch_model["params"], strict=False)
|
||||
|
||||
model.to(ctx.training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
|
||||
input_names = ["data"]
|
||||
output_names = ["output"]
|
||||
dynamic_axes = {
|
||||
"data": {2: "width", 3: "height"},
|
||||
"output": {2: "width", 3: "height"},
|
||||
}
|
||||
|
||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
dest_onnx,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=opset,
|
||||
export_params=True,
|
||||
)
|
||||
logger.info("Real ESRGAN exported to ONNX successfully.")
|
|
@ -0,0 +1,7 @@
|
|||
import torch
|
||||
|
||||
class ConversionContext:
|
||||
def __init__(self, model_path: str, device: str) -> None:
|
||||
self.model_path = model_path
|
||||
self.training_device = device
|
||||
self.map_location = torch.device(device)
|
|
@ -96,7 +96,7 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
if device is not None and hasattr(pipe, "to"):
|
||||
pipe = pipe.to(device)
|
||||
pipe = pipe.to(device.torch_device())
|
||||
|
||||
last_pipeline_instance = pipe
|
||||
last_pipeline_options = options
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "images"
|
||||
cond_stage_key: "input_ids"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: true # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: [ ]
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
Loading…
Reference in New Issue