
feat(api): convert from SD checkpoints (#117)

Sean Sube 2023-02-08 22:35:54 -06:00
parent cd4a0f10b0
commit 4f71348f98
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 2050 additions and 298 deletions


@@ -0,0 +1,177 @@
from .correction_gfpgan import convert_correction_gfpgan
from .diffusion_original import convert_diffusion_original
from .diffusion_stable import convert_diffusion_stable
from .upscale_resrgan import convert_upscale_resrgan
from .utils import ConversionContext

import warnings
from argparse import ArgumentParser
from json import loads
from logging import getLogger
from os import environ, makedirs, path
from sys import exit
from typing import Dict, List, Optional, Tuple

import torch

# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
warnings.filterwarnings(
    "ignore", ".*The shape inference of prim::Constant type is missing.*"
)
warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
warnings.filterwarnings(
    "ignore",
    ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
)

Models = Dict[str, List[Tuple[str, str, Optional[int]]]]

logger = getLogger(__name__)

# recommended models
base_models: Models = {
    "diffusion": [
        # v1.x
        ("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
        ("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
        # v2.x
        ("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
        (
            "stable-diffusion-onnx-v2-inpainting",
            "stabilityai/stable-diffusion-2-inpainting",
        ),
        # TODO: should have its own converter
        ("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
        # TODO: testing safetensors
        ("diffusion-stably-diffused-onnx-v2-6", "../models/tensors/stablydiffuseds_26.safetensors"),
        ("diffusion-unstable-ink-dream-onnx-v6", "../models/tensors/unstableinkdream_v6.safetensors"),
    ],
    "correction": [
        (
            "correction-gfpgan-v1-3",
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
            4,
        ),
        (
            "correction-codeformer",
            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
            1,
        ),
    ],
    "upscaling": [
        (
            "upscaling-real-esrgan-x2-plus",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            2,
        ),
        (
            "upscaling-real-esrgan-x4-plus",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            4,
        ),
        (
            "upscaling-real-esrgan-x4-v3",
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
            4,
        ),
    ],
}

model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
training_device = "cuda" if torch.cuda.is_available() else "cpu"


def load_models(args, ctx: ConversionContext, models: Models):
    if args.diffusion:
        for source in models.get("diffusion"):
            name, file = source
            if name in args.skip:
                logger.info("Skipping model: %s", source[0])
            else:
                if file.endswith(".safetensors") or file.endswith(".ckpt"):
                    convert_diffusion_original(ctx, *source, args.opset, args.half)
                else:
                    # TODO: make this a parameter in the JSON/dict
                    single_vae = "upscaling" in source[0]
                    convert_diffusion_stable(
                        ctx, *source, args.opset, args.half, args.token, single_vae=single_vae
                    )

    if args.upscaling:
        for source in models.get("upscaling"):
            if source[0] in args.skip:
                logger.info("Skipping model: %s", source[0])
            else:
                convert_upscale_resrgan(ctx, *source, args.opset)

    if args.correction:
        for source in models.get("correction"):
            if source[0] in args.skip:
                logger.info("Skipping model: %s", source[0])
            else:
                convert_correction_gfpgan(ctx, *source, args.opset)


def main() -> int:
    parser = ArgumentParser(
        prog="onnx-web model converter", description="convert checkpoint models to ONNX"
    )

    # model groups
    parser.add_argument("--correction", action="store_true", default=False)
    parser.add_argument("--diffusion", action="store_true", default=False)
    parser.add_argument("--upscaling", action="store_true", default=False)

    # extra models
    parser.add_argument("--extras", nargs="*", type=str, default=[])
    parser.add_argument("--skip", nargs="*", type=str, default=[])

    # export options
    parser.add_argument(
        "--half",
        action="store_true",
        default=False,
        help="Export models for half precision, faster on some Nvidia cards.",
    )
    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument(
        "--token",
        type=str,
        help="HuggingFace token with read permissions for downloading models.",
    )

    args = parser.parse_args()
    logger.info("CLI arguments: %s", args)

    ctx = ConversionContext(model_path, training_device)
    logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)

    if not path.exists(model_path):
        logger.info("Model path does not exist, creating: %s", model_path)
        makedirs(model_path)

    logger.info("Converting base models.")
    load_models(args, ctx, base_models)

    for file in args.extras:
        if file is not None and file != "":
            logger.info("Loading extra models from %s", file)
            try:
                with open(file, "r") as f:
                    data = loads(f.read())
                    logger.info("Converting extra models.")
                    load_models(args, ctx, data)
            except Exception as err:
                logger.error("Error converting extra models: %s", err)

    return 0


if __name__ == "__main__":
    exit(main())
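
The `--extras` flag above loads additional model definitions from JSON files with the same shape as `base_models`. A minimal sketch of producing such a file; the model name and tensor path are placeholders, not part of this commit (JSON arrays unpack the same way as the tuples in `base_models`):

from json import dumps

# hypothetical extras file matching the Models type consumed by load_models
extras = {
    "diffusion": [
        ["diffusion-example-onnx", "../models/tensors/example.safetensors"],
    ],
    "upscaling": [],
    "correction": [],
}

with open("extras.json", "w") as f:
    f.write(dumps(extras, indent=2))

The converter would then pick it up with something like `--diffusion --extras extras.json`.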


@@ -0,0 +1,68 @@
import torch
from shutil import copyfile
from basicsr.utils.download_util import load_file_from_url
from torch.onnx import export
from os import path
from logging import getLogger
from basicsr.archs.rrdbnet_arch import RRDBNet

from .utils import ConversionContext

logger = getLogger(__name__)


@torch.no_grad()
def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
    dest_path = path.join(ctx.model_path, name + ".pth")
    dest_onnx = path.join(ctx.model_path, name + ".onnx")
    logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)

    if path.isfile(dest_onnx):
        logger.info("ONNX model already exists, skipping.")
        return

    if not path.isfile(dest_path):
        logger.info("PTH model not found, downloading...")
        download_path = load_file_from_url(
            url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
        )
        copyfile(download_path, dest_path)

    logger.info("loading and training model")
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )

    torch_model = torch.load(dest_path, map_location=ctx.map_location)
    # TODO: make sure strict=False is safe here
    if "params_ema" in torch_model:
        model.load_state_dict(torch_model["params_ema"], strict=False)
    else:
        model.load_state_dict(torch_model["params"], strict=False)

    model.to(ctx.training_device).train(False)
    model.eval()

    rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
    input_names = ["data"]
    output_names = ["output"]
    dynamic_axes = {
        "data": {2: "width", 3: "height"},
        "output": {2: "width", 3: "height"},
    }

    logger.info("exporting ONNX model to %s", dest_onnx)
    export(
        model,
        rng,
        dest_onnx,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset,
        export_params=True,
    )
    logger.info("GFPGAN exported to ONNX successfully.")

File diff suppressed because it is too large.


@@ -1,208 +1,23 @@
-import warnings
-from argparse import ArgumentParser
-from json import loads
-from logging import getLogger
-from os import environ, makedirs, mkdir, path
-from pathlib import Path
-from shutil import copyfile, rmtree
-from sys import exit
-from typing import Dict, List, Optional, Tuple
-import torch
-from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils.download_util import load_file_from_url
 from diffusers import (
     OnnxRuntimeModel,
     OnnxStableDiffusionPipeline,
     StableDiffusionPipeline,
-    StableDiffusionUpscalePipeline,
 )
-from onnx import load, save_model
 from torch.onnx import export
-# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
-warnings.filterwarnings(
-    "ignore", ".*The shape inference of prim::Constant type is missing.*"
-)
-warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*")
-warnings.filterwarnings(
-    "ignore",
-    ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
-)
-Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
+from logging import getLogger
+from shutil import rmtree
+import torch
+from os import path, mkdir
+from pathlib import Path
+from onnx import load, save_model
+from .utils import ConversionContext
+from ..diffusion.pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
 logger = getLogger(__name__)
-# recommended models
-base_models: Models = {
-    "diffusion": [
-        # v1.x
-        ("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
-        ("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
-        # v2.x
-        ("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
-        (
-            "stable-diffusion-onnx-v2-inpainting",
-            "stabilityai/stable-diffusion-2-inpainting",
-        ),
-        # TODO: should have its own converter
-        ("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
-    ],
-    "correction": [
-        (
-            "correction-gfpgan-v1-3",
-            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
-            4,
-        ),
-        (
-            "correction-codeformer",
-            "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
-            1,
-        ),
-    ],
-    "upscaling": [
-        (
-            "upscaling-real-esrgan-x2-plus",
-            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
-            2,
-        ),
-        (
-            "upscaling-real-esrgan-x4-plus",
-            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
-            4,
-        ),
-        (
-            "upscaling-real-esrgan-x4-v3",
-            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
-            4,
-        ),
-    ],
-}
-model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
-training_device = "cuda" if torch.cuda.is_available() else "cpu"
-map_location = torch.device(training_device)
-@torch.no_grad()
-def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
-    dest_path = path.join(model_path, name + ".pth")
-    dest_onnx = path.join(model_path, name + ".onnx")
-    logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
-    if path.isfile(dest_onnx):
-        logger.info("ONNX model already exists, skipping.")
-        return
-    if not path.isfile(dest_path):
-        logger.info("PTH model not found, downloading...")
-        download_path = load_file_from_url(
-            url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
-        )
-        copyfile(download_path, dest_path)
-    logger.info("loading and training model")
-    model = RRDBNet(
-        num_in_ch=3,
-        num_out_ch=3,
-        num_feat=64,
-        num_block=23,
-        num_grow_ch=32,
-        scale=scale,
-    )
-    torch_model = torch.load(dest_path, map_location=map_location)
-    if "params_ema" in torch_model:
-        model.load_state_dict(torch_model["params_ema"])
-    else:
-        model.load_state_dict(torch_model["params"], strict=False)
-    model.to(training_device).train(False)
-    model.eval()
-    rng = torch.rand(1, 3, 64, 64, device=map_location)
-    input_names = ["data"]
-    output_names = ["output"]
-    dynamic_axes = {
-        "data": {2: "width", 3: "height"},
-        "output": {2: "width", 3: "height"},
-    }
-    logger.info("exporting ONNX model to %s", dest_onnx)
-    export(
-        model,
-        rng,
-        dest_onnx,
-        input_names=input_names,
-        output_names=output_names,
-        dynamic_axes=dynamic_axes,
-        opset_version=opset,
-        export_params=True,
-    )
-    logger.info("Real ESRGAN exported to ONNX successfully.")
-@torch.no_grad()
-def convert_gfpgan(name: str, url: str, scale: int, opset: int):
-    dest_path = path.join(model_path, name + ".pth")
-    dest_onnx = path.join(model_path, name + ".onnx")
-    logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
-    if path.isfile(dest_onnx):
-        logger.info("ONNX model already exists, skipping.")
-        return
-    if not path.isfile(dest_path):
-        logger.info("PTH model not found, downloading...")
-        download_path = load_file_from_url(
-            url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
-        )
-        copyfile(download_path, dest_path)
-    logger.info("loading and training model")
-    model = RRDBNet(
-        num_in_ch=3,
-        num_out_ch=3,
-        num_feat=64,
-        num_block=23,
-        num_grow_ch=32,
-        scale=scale,
-    )
-    torch_model = torch.load(dest_path, map_location=map_location)
-    # TODO: make sure strict=False is safe here
-    if "params_ema" in torch_model:
-        model.load_state_dict(torch_model["params_ema"], strict=False)
-    else:
-        model.load_state_dict(torch_model["params"], strict=False)
-    model.to(training_device).train(False)
-    model.eval()
-    rng = torch.rand(1, 3, 64, 64, device=map_location)
-    input_names = ["data"]
-    output_names = ["output"]
-    dynamic_axes = {
-        "data": {2: "width", 3: "height"},
-        "output": {2: "width", 3: "height"},
-    }
-    logger.info("exporting ONNX model to %s", dest_onnx)
-    export(
-        model,
-        rng,
-        dest_onnx,
-        input_names=input_names,
-        output_names=output_names,
-        dynamic_axes=dynamic_axes,
-        opset_version=opset,
-        export_params=True,
-    )
-    logger.info("GFPGAN exported to ONNX successfully.")
 def onnx_export(
     model,
     model_args: tuple,
@@ -231,33 +46,39 @@ def onnx_export(
 @torch.no_grad()
-def convert_diffuser(
-    name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False
+def convert_diffusion_stable(
+    ctx: ConversionContext,
+    name: str,
+    url: str,
+    opset: int,
+    half: bool,
+    token: str,
+    single_vae: bool = False,
 ):
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
     dtype = torch.float16 if half else torch.float32
-    dest_path = path.join(model_path, name)
+    dest_path = path.join(ctx.model_path, name)
     # diffusers go into a directory rather than .onnx file
-    logger.info("converting Diffusers model: %s -> %s/", name, dest_path)
+    logger.info("converting Stable Diffusion model %s: %s -> %s/", name, url, dest_path)
     if single_vae:
         logger.info("converting model with single VAE")
-    if path.isdir(dest_path):
+    if path.exists(dest_path):
         logger.info("ONNX model already exists, skipping.")
         return
-    if half and training_device != "cuda":
+    if half and ctx.training_device != "cuda":
         raise ValueError(
             "Half precision model export is only supported on GPUs with CUDA"
         )
     pipeline = StableDiffusionPipeline.from_pretrained(
         url, torch_dtype=dtype, use_auth_token=token
-    ).to(training_device)
+    ).to(ctx.training_device)
     output_path = Path(dest_path)
     # TEXT ENCODER
@@ -273,7 +94,7 @@ def convert_diffuser(
     onnx_export(
         pipeline.text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
-        model_args=(text_input.input_ids.to(device=training_device, dtype=torch.int32)),
+        model_args=(text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)),
         output_path=output_path / "text_encoder" / "model.onnx",
         ordered_input_names=["input_ids"],
         output_names=["last_hidden_state", "pooler_output"],
@@ -302,11 +123,11 @@ def convert_diffuser(
         pipeline.unet,
         model_args=(
             torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
-                device=training_device, dtype=dtype
+                device=ctx.training_device, dtype=dtype
             ),
-            torch.randn(2).to(device=training_device, dtype=dtype),
+            torch.randn(2).to(device=ctx.training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
-                device=training_device, dtype=dtype
+                device=ctx.training_device, dtype=dtype
             ),
             unet_scale,
         ),
@@ -353,7 +174,7 @@ def convert_diffuser(
         model_args=(
             torch.randn(
                 1, vae_latent_channels, unet_sample_size, unet_sample_size
-            ).to(device=training_device, dtype=dtype),
+            ).to(device=ctx.training_device, dtype=dtype),
             False,
         ),
         output_path=output_path / "vae" / "model.onnx",
@@ -377,7 +198,7 @@ def convert_diffuser(
         vae_encoder,
         model_args=(
             torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
-                device=training_device, dtype=dtype
+                device=ctx.training_device, dtype=dtype
             ),
             False,
         ),
@@ -401,7 +222,7 @@ def convert_diffuser(
         model_args=(
             torch.randn(
                 1, vae_latent_channels, unet_sample_size, unet_sample_size
-            ).to(device=training_device, dtype=dtype),
+            ).to(device=ctx.training_device, dtype=dtype),
             False,
         ),
         output_path=output_path / "vae_decoder" / "model.onnx",
@@ -429,9 +250,9 @@ def convert_diffuser(
                 clip_num_channels,
                 clip_image_size,
                 clip_image_size,
-            ).to(device=training_device, dtype=dtype),
+            ).to(device=ctx.training_device, dtype=dtype),
             torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
-                device=training_device, dtype=dtype
+                device=ctx.training_device, dtype=dtype
             ),
         ),
         output_path=output_path / "safety_checker" / "model.onnx",
@@ -453,7 +274,7 @@ def convert_diffuser(
     feature_extractor = None
     if single_vae:
-        onnx_pipeline = StableDiffusionUpscalePipeline(
+        onnx_pipeline = OnnxStableDiffusionUpscalePipeline(
            vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"),
            text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
            tokenizer=pipeline.tokenizer,
@@ -483,7 +304,7 @@ def convert_diffuser(
     del onnx_pipeline
     if single_vae:
-        _ = StableDiffusionUpscalePipeline.from_pretrained(
+        _ = OnnxStableDiffusionUpscalePipeline.from_pretrained(
            output_path, provider="CPUExecutionProvider"
        )
     else:
@@ -493,89 +314,3 @@ def convert_diffuser(
     logger.info("ONNX pipeline is loadable")
-def load_models(args, models: Models):
-    if args.diffusion:
-        for source in models.get("diffusion"):
-            if source[0] in args.skip:
-                logger.info("Skipping model: %s", source[0])
-            else:
-                single_vae = "upscaling" in source[0]
-                convert_diffuser(
-                    *source, args.opset, args.half, args.token, single_vae=single_vae
-                )
-    if args.upscaling:
-        for source in models.get("upscaling"):
-            if source[0] in args.skip:
-                logger.info("Skipping model: %s", source[0])
-            else:
-                convert_real_esrgan(*source, args.opset)
-    if args.correction:
-        for source in models.get("correction"):
-            if source[0] in args.skip:
-                logger.info("Skipping model: %s", source[0])
-            else:
-                convert_gfpgan(*source, args.opset)
-def main() -> int:
-    parser = ArgumentParser(
-        prog="onnx-web model converter", description="convert checkpoint models to ONNX"
-    )
-    # model groups
-    parser.add_argument("--correction", action="store_true", default=False)
-    parser.add_argument("--diffusion", action="store_true", default=False)
-    parser.add_argument("--upscaling", action="store_true", default=False)
-    # extra models
-    parser.add_argument("--extras", nargs="*", type=str, default=[])
-    parser.add_argument("--skip", nargs="*", type=str, default=[])
-    # export options
-    parser.add_argument(
-        "--half",
-        action="store_true",
-        default=False,
-        help="Export models for half precision, faster on some Nvidia cards.",
-    )
-    parser.add_argument(
-        "--opset",
-        default=14,
-        type=int,
-        help="The version of the ONNX operator set to use.",
-    )
-    parser.add_argument(
-        "--token",
-        type=str,
-        help="HuggingFace token with read permissions for downloading models.",
-    )
-    args = parser.parse_args()
-    logger.info("CLI arguments: %s", args)
-    if not path.exists(model_path):
-        logger.info("Model path does not existing, creating: %s", model_path)
-        makedirs(model_path)
-    logger.info("Converting base models.")
-    load_models(args, base_models)
-    for file in args.extras:
-        if file is not None and file != "":
-            logger.info("Loading extra models from %s", file)
-            try:
-                with open(file, "r") as f:
-                    data = loads(f.read())
-                    logger.info("Converting extra models.")
-                    load_models(args, data)
-            except Exception as err:
-                logger.error("Error converting extra models: %s", err)
-    return 0
-if __name__ == "__main__":
-    exit(main())
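
For reference, a direct call to the renamed function now threads the conversion context through explicitly. A sketch mirroring what load_models passes for the v1.5 entry; the token here is a placeholder:

from .utils import ConversionContext
from .diffusion_stable import convert_diffusion_stable

ctx = ConversionContext("../models", "cpu")
convert_diffusion_stable(
    ctx,
    "stable-diffusion-onnx-v1-5",
    "runwayml/stable-diffusion-v1-5",
    14,     # opset
    False,  # half precision requires CUDA
    None,   # or a HuggingFace read token for gated models
    single_vae=False,
)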


@@ -0,0 +1,68 @@
import torch
from shutil import copyfile
from basicsr.utils.download_util import load_file_from_url
from torch.onnx import export
from os import path
from logging import getLogger
from basicsr.archs.rrdbnet_arch import RRDBNet

from .utils import ConversionContext

logger = getLogger(__name__)


@torch.no_grad()
def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
    dest_path = path.join(ctx.model_path, name + ".pth")
    dest_onnx = path.join(ctx.model_path, name + ".onnx")
    logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)

    if path.isfile(dest_onnx):
        logger.info("ONNX model already exists, skipping.")
        return

    if not path.isfile(dest_path):
        logger.info("PTH model not found, downloading...")
        download_path = load_file_from_url(
            url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
        )
        copyfile(download_path, dest_path)

    logger.info("loading and training model")
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )

    torch_model = torch.load(dest_path, map_location=ctx.map_location)
    if "params_ema" in torch_model:
        model.load_state_dict(torch_model["params_ema"])
    else:
        model.load_state_dict(torch_model["params"], strict=False)

    model.to(ctx.training_device).train(False)
    model.eval()

    rng = torch.rand(1, 3, 64, 64, device=ctx.map_location)
    input_names = ["data"]
    output_names = ["output"]
    dynamic_axes = {
        "data": {2: "width", 3: "height"},
        "output": {2: "width", 3: "height"},
    }

    logger.info("exporting ONNX model to %s", dest_onnx)
    export(
        model,
        rng,
        dest_onnx,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset,
        export_params=True,
    )
    logger.info("Real ESRGAN exported to ONNX successfully.")


@@ -0,0 +1,7 @@
import torch


class ConversionContext:
    def __init__(self, model_path: str, device: str) -> None:
        self.model_path = model_path
        self.training_device = device
        self.map_location = torch.device(device)
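
Usage mirrors how main() in the converter entry point builds the context; a sketch:

import torch

# ConversionContext imported from the convert package's utils module;
# same defaults as the entry point: shared model directory plus the
# torch device used for both loading and export
device = "cuda" if torch.cuda.is_available() else "cpu"
ctx = ConversionContext("../models", device)
print(ctx.model_path, ctx.training_device, ctx.map_location)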


@@ -96,7 +96,7 @@ def load_pipeline(
     )
     if device is not None and hasattr(pipe, "to"):
-        pipe = pipe.to(device)
+        pipe = pipe.to(device.torch_device())
     last_pipeline_instance = pipe
     last_pipeline_options = options


@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "images"
    cond_stage_key: "input_ids"
    image_size: 64
    channels: 4
    cond_stage_trainable: true # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
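
This is the standard LDM-style inference config, presumably paired with .ckpt/.safetensors files by the new original-checkpoint converter. In the upstream LDM tooling such files are parsed with OmegaConf; a sketch, with an assumed file name:

from omegaconf import OmegaConf

config = OmegaConf.load("v1-inference.yaml")  # file name is an assumption
print(config.model.target)  # ldm.models.diffusion.ddpm.LatentDiffusion
print(config.model.params.unet_config.params.context_dim)  # 768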