
feat(api): add a way to download models from civitai or other https sources (#117)

Sean Sube 2023-02-10 22:41:24 -06:00
parent b3e4076775
commit 9f202486c2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
12 changed files with 489 additions and 366 deletions

View File

@@ -1,9 +1,23 @@
{
"diffusion": [
["diffusion-knollingcase", "Aybeeceedee/knollingcase"],
["diffusion-openjourney", "prompthero/openjourney"],
["diffusion-stably-diffused-onnx-v2-6", "../models/tensors/stablydiffuseds_26.safetensors"],
["diffusion-unstable-ink-dream-onnx-v6", "../models/tensors/unstableinkdream_v6.safetensors"]
{
"name": "diffusion-knollingcase",
"source": "Aybeeceedee/knollingcase"
},
{
"name": "diffusion-openjourney",
"source": "prompthero/openjourney"
},
{
"name": "diffusion-stablydiffused-aesthetic-v2-6",
"source": "civitai://6266?type=Pruned%20Model&format=SafeTensor",
"format": "safetensors"
},
{
"name": "diffusion-unstable-ink-dream-v6",
"source": "civitai://5796",
"format": "safetensors"
}
],
"correction": [],
"upscaling": []

View File

@@ -1,9 +1,3 @@
from .correction_gfpgan import convert_correction_gfpgan
from .diffusion_original import convert_diffusion_original
from .diffusion_stable import convert_diffusion_stable
from .upscale_resrgan import convert_upscale_resrgan
from .utils import ConversionContext
import warnings
from argparse import ArgumentParser
from json import loads
@@ -11,9 +5,17 @@ from logging import getLogger
from os import environ, makedirs, path
from sys import exit
from typing import Dict, List, Optional, Tuple
from yaml import safe_load
from jsonschema import validate, ValidationError
import torch
from .correction_gfpgan import convert_correction_gfpgan
from .diffusion_original import convert_diffusion_original
from .diffusion_stable import convert_diffusion_stable
from .upscale_resrgan import convert_upscale_resrgan
from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
warnings.filterwarnings(
"ignore", ".*The shape inference of prim::Constant type is missing.*"
@@ -29,20 +31,39 @@ Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
logger = getLogger(__name__)
model_sources: Dict[str, Tuple[str, str]] = {
"civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
}
model_source_huggingface = "huggingface://"
# recommended models
base_models: Models = {
"diffusion": [
# v1.x
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
# v2.x
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
(
"stable-diffusion-onnx-v2-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
"stable-diffusion-onnx-v1-5",
model_source_huggingface + "runwayml/stable-diffusion-v1-5",
),
# (
# "stable-diffusion-onnx-v1-inpainting",
# model_source_huggingface + "runwayml/stable-diffusion-inpainting",
# ),
# v2.x
# (
# "stable-diffusion-onnx-v2-1",
# model_source_huggingface + "stabilityai/stable-diffusion-2-1",
# ),
# (
# "stable-diffusion-onnx-v2-inpainting",
# model_source_huggingface + "stabilityai/stable-diffusion-2-inpainting",
# ),
# TODO: should have its own converter
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
(
"upscaling-stable-diffusion-x4",
model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
True,
),
],
"correction": [
(
@@ -79,35 +100,86 @@ model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
training_device = "cuda" if torch.cuda.is_available() else "cpu"
def load_models(args, ctx: ConversionContext, models: Models):
def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str:
cache_name = path.join(ctx.cache_path, name)
if format is not None:
# add an extension if possible, some of the conversion code checks for it
cache_name = "%s.%s" % (cache_name, format)
for proto in model_sources:
api_name, api_root = model_sources.get(proto)
if source.startswith(proto):
api_source = api_root % (source.removeprefix(proto))
logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
return download_progress([(api_source, cache_name)])
if source.startswith(model_source_huggingface):
hub_source = source.removeprefix(model_source_huggingface)
logger.info("Downloading model from Huggingface Hub: %s", hub_source)
# from_pretrained has a bunch of useful logic that snapshot_download by itself does not
return hub_source
elif source.startswith("https://"):
logger.info("Downloading model from: %s", source)
return download_progress([(source, cache_name)])
elif source.startswith("http://"):
logger.warning("Downloading model from insecure source: %s", source)
return download_progress([(source, cache_name)])
elif source.startswith(path.sep) or source.startswith("."):
logger.info("Using local model: %s", source)
return source
else:
logger.info("Unknown model location, using path as provided: %s", source)
return source
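Illustrative only: with the default cache path from ConversionContext below (model_path/.cache), fetch_model resolves the various source schemes roughly as follows.

ctx = ConversionContext("../models", "cpu")  # cache_path defaults to ../models/.cache
fetch_model(ctx, "diffusion-unstable-ink-dream-v6", "civitai://5796", format="safetensors")
# downloads https://civitai.com/api/download/models/5796
# to ../models/.cache/diffusion-unstable-ink-dream-v6.safetensors
fetch_model(ctx, "stable-diffusion-onnx-v1-5", "huggingface://runwayml/stable-diffusion-v1-5")
# returns "runwayml/stable-diffusion-v1-5", leaving the download to from_pretrained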
def convert_models(ctx: ConversionContext, args, models: Models):
if args.diffusion:
for source in models.get("diffusion"):
name, file = source
for model in models.get("diffusion"):
model = tuple_to_diffusion(model)
name = model.get("name")
if name in args.skip:
logger.info("Skipping model: %s", source[0])
logger.info("Skipping model: %s", name)
else:
if file.endswith(".safetensors") or file.endswith(".ckpt"):
convert_diffusion_original(ctx, *source, args.opset, args.half)
format = source_format(model)
source = fetch_model(ctx, name, model["source"], format=format)
if format in ["safetensors", "ckpt"]:
convert_diffusion_original(
ctx,
model,
source,
)
else:
# TODO: make this a parameter in the JSON/dict
single_vae = "upscaling" in source[0]
convert_diffusion_stable(
ctx, *source, args.opset, args.half, args.token, single_vae=single_vae
ctx,
model,
source,
)
if args.upscaling:
for source in models.get("upscaling"):
if source[0] in args.skip:
logger.info("Skipping model: %s", source[0])
for model in models.get("upscaling"):
model = tuple_to_upscaling(model)
name = model.get("name")
if name in args.skip:
logger.info("Skipping model: %s", name)
else:
convert_upscale_resrgan(ctx, *source, args.opset)
format = source_format(model)
source = fetch_model(ctx, name, model["source"], format=format)
convert_upscale_resrgan(ctx, model, source)
if args.correction:
for source in models.get("correction"):
if source[0] in args.skip:
logger.info("Skipping model: %s", source[0])
for model in models.get("correction"):
model = tuple_to_correction(model)
name = model.get("name")
if name in args.skip:
logger.info("Skipping model: %s", name)
else:
convert_correction_gfpgan(ctx, *source, args.opset)
format = source_format(model)
source = fetch_model(ctx, name, model["source"], format=format)
convert_correction_gfpgan(ctx, model, source)
def main() -> int:
@@ -146,7 +218,7 @@ def main() -> int:
args = parser.parse_args()
logger.info("CLI arguments: %s", args)
ctx = ConversionContext(model_path, training_device)
ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
if not path.exists(model_path):
@@ -154,16 +226,26 @@ def main() -> int:
makedirs(model_path)
logger.info("Converting base models.")
load_models(args, ctx, base_models)
convert_models(ctx, args, base_models)
for file in args.extras:
if file is not None and file != "":
logger.info("Loading extra models from %s", file)
try:
with open(file, "r") as f:
data = loads(f.read())
data = safe_load(f.read())
with open("./schemas/extras.yaml", "r") as f:
schema = safe_load(f.read())
logger.debug("validating chain request: %s against %s", data, schema)
try:
validate(data, schema)
logger.info("Converting extra models.")
load_models(args, ctx, data)
convert_models(ctx, args, data)
except ValidationError as err:
logger.error("Invalid data in extras file: %s", err)
except Exception as err:
logger.error("Error converting extra models: %s", err)

View File

@@ -1,31 +1,34 @@
import torch
from logging import getLogger
from os import path
from shutil import copyfile
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from torch.onnx import export
from os import path
from logging import getLogger
from basicsr.archs.rrdbnet_arch import RRDBNet
from .utils import ConversionContext
from .utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@torch.no_grad()
def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
dest_path = path.join(ctx.model_path, name + ".pth")
dest_onnx = path.join(ctx.model_path, name + ".onnx")
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
if path.isfile(dest_onnx):
@torch.no_grad()
def convert_correction_gfpgan(
ctx: ConversionContext,
model: ModelDict,
source: str,
):
name = model.get("name")
source = source or model.get("source")
scale = model.get("scale")
dest = path.join(ctx.model_path, name + ".onnx")
logger.info("converting GFPGAN model: %s -> %s", name, dest)
if path.isfile(dest):
logger.info("ONNX model already exists, skipping.")
return
if not path.isfile(dest_path):
logger.info("PTH model not found, downloading...")
download_path = load_file_from_url(
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
)
copyfile(download_path, dest_path)
logger.info("loading and training model")
model = RRDBNet(
num_in_ch=3,
@@ -36,7 +39,7 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale
scale=scale,
)
torch_model = torch.load(dest_path, map_location=ctx.map_location)
torch_model = torch.load(source, map_location=ctx.map_location)
# TODO: make sure strict=False is safe here
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"], strict=False)
@@ -54,15 +57,15 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale
"output": {2: "width", 3: "height"},
}
logger.info("exporting ONNX model to %s", dest_onnx)
logger.info("exporting ONNX model to %s", dest)
export(
model,
rng,
dest_onnx,
dest,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=opset,
opset_version=ctx.opset,
export_params=True,
)
logger.info("GFPGAN exported to ONNX successfully.")

View File

@@ -11,6 +11,17 @@
# TODO: ask about license before merging
###
import json
import os
import re
import shutil
import traceback
from logging import getLogger
from typing import Dict, List
import huggingface_hub.utils.tqdm
import safetensors.torch
import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
@@ -20,95 +31,34 @@ from diffusers import (
HeunDiscreteScheduler,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PaintByExamplePipeline,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel, PaintByExamplePipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
LDMBertConfig,
LDMBertModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfApi, hf_hub_download
from logging import getLogger
from omegaconf import OmegaConf
from pydantic import BaseModel
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
from typing import Dict, List, Union
import huggingface_hub.utils.tqdm
import json
import os
import re
import safetensors.torch
import shutil
import torch
import traceback
from transformers import (
AutoFeatureExtractor,
BertTokenizerFast,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionConfig,
)
from .diffusion_stable import convert_diffusion_stable
from .utils import ConversionContext
from .utils import ConversionContext, ModelDict
logger = getLogger(__name__)
def get_images():
return []
class Concept(BaseModel):
class_data_dir: str = ""
class_guidance_scale: float = 7.5
class_infer_steps: int = 60
class_negative_prompt: str = ""
class_prompt: str = ""
class_token: str = ""
instance_data_dir: str = ""
instance_prompt: str = ""
instance_token: str = ""
is_valid: bool = False
n_save_sample: int = 1
num_class_images: int = 0
num_class_images_per: int = 0
sample_seed: int = -1
save_guidance_scale: float = 7.5
save_infer_steps: int = 60
save_sample_negative_prompt: str = ""
save_sample_prompt: str = ""
save_sample_template: str = ""
def __init__(self, input_dict: Union[Dict, None] = None, **kwargs):
super().__init__(**kwargs)
if input_dict is not None:
self.load_params(input_dict)
if self.is_valid and self.num_class_images != 0:
if self.num_class_images_per == 0:
images = get_images(self.instance_data_dir)
if len(images) < self.num_class_images * 2:
self.num_class_images_per = 1
else:
self.num_class_images_per = self.num_class_images // len(images)
self.num_class_images = 0
def to_dict(self):
return self.dict()
def to_json(self):
return json.dumps(self.to_dict())
def load_params(self, params_dict):
for key, value in params_dict.items():
if hasattr(self, key):
setattr(self, key, value)
if self.instance_data_dir:
self.is_valid = os.path.isdir(self.instance_data_dir)
else:
self.is_valid = False
# Keys to save, replacing our dumb __init__ method
save_keys = []
# Keys to return to the ui when Load Settings is clicked.
ui_keys = []
def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in "._- "))
@@ -200,7 +150,7 @@ class DreamboothConfig(BaseModel):
super().__init__(**kwargs)
model_name = sanitize_name(model_name)
model_dir = os.path.join(ctx.model_path, model_name)
model_dir = os.path.join(ctx.cache_path, model_name)
working_dir = os.path.join(model_dir, "working")
if not os.path.exists(working_dir):
@@ -214,7 +164,6 @@ class DreamboothConfig(BaseModel):
self.scheduler = scheduler
self.v2 = v2
# Actually save as a file
def save(self, backup=False):
"""
Save the config file
@@ -236,132 +185,6 @@ class DreamboothConfig(BaseModel):
if hasattr(self, key):
setattr(self, key, value)
# Pass a dict and return a list of Concept objects
def concepts(self, required: int = -1):
concepts = []
c_idx = 0
# If using a file for concepts and not requesting from UI, load from file
if self.use_concepts and self.concepts_path and required == -1:
concepts_list = concepts_from_file(self.concepts_path)
# Otherwise, use 'stored' list
else:
concepts_list = self.concepts_list
if required == -1:
required = len(concepts_list)
for concept_dict in concepts_list:
concept = Concept(input_dict=concept_dict)
if concept.is_valid:
if concept.class_data_dir == "" or concept.class_data_dir is None:
concept.class_data_dir = os.path.join(self.model_dir, f"classifiers_{c_idx}")
concepts.append(concept)
c_idx += 1
missing = len(concepts) - required
if missing > 0:
concepts.extend([Concept(None)] * missing)
return concepts
# Set default values
def check_defaults(self):
if self.model_name is not None and self.model_name != "":
if self.revision == "" or self.revision is None:
self.revision = 0
if self.epoch == "" or self.epoch is None:
self.epoch = 0
self.model_name = "".join(x for x in self.model_name if (x.isalnum() or x in "._- "))
models_path = "." # TODO: use ctx path
model_dir = os.path.join(models_path, self.model_name)
working_dir = os.path.join(model_dir, "working")
if not os.path.exists(working_dir):
os.makedirs(working_dir)
self.model_dir = model_dir
self.pretrained_model_name_or_path = working_dir
def concepts_from_file(concepts_path: str):
concepts = []
if os.path.exists(concepts_path) and os.path.isfile(str):
try:
with open(concepts_path,"r") as concepts_file:
concepts_str = concepts_file.read()
except Exception as e:
print(f"Exception opening concepts file: {e}")
else:
concepts_str = concepts_path
try:
concepts_data = json.loads(concepts_str)
for concept_data in concepts_data:
concept = Concept(input_dict=concept_data)
if concept.is_valid:
concepts.append(concept.__dict__)
except Exception as e:
print(f"Exception parsing concepts: {e}")
return concepts
def save_config(*args):
raise Exception("where tho")
params = list(args)
concept_keys = ["c1_", "c2_", "c3_", "c4_"]
model_name = params[0]
if model_name is None or model_name == "":
print("Invalid model name.")
return
config = from_file(ctx, model_name)
if config is None:
config = DreamboothConfig(model_name)
params_dict = dict(zip(save_keys, params))
concepts_list = []
# If using a concepts file/string, keep concepts_list empty.
if params_dict["db_use_concepts"] and params_dict["db_concepts_path"]:
concepts_list = []
params_dict["concepts_list"] = concepts_list
else:
for concept_key in concept_keys:
concept_dict = {}
for key, param in params_dict.items():
if concept_key in key and param is not None:
concept_dict[key.replace(concept_key, "")] = param
concept_test = Concept(concept_dict)
if concept_test.is_valid:
concepts_list.append(concept_test.__dict__)
existing_concepts = params_dict["concepts_list"] if "concepts_list" in params_dict else []
if len(concepts_list) and not len(existing_concepts):
params_dict["concepts_list"] = concepts_list
config.load_params(params_dict)
config.save()
def from_file(ctx: ConversionContext, model_name):
"""
Load config data from UI
Args:
model_name: The config to load
Returns: Dict | None
"""
if model_name == "" or model_name is None:
return None
model_name = sanitize_name(model_name)
config_file = os.path.join(ctx.model_path, model_name, "db_config.json")
try:
with open(config_file, 'r') as openfile:
config_dict = json.load(openfile)
config = DreamboothConfig(model_name)
config.load_params(config_dict)
return config
except Exception as e:
print(f"Exception loading config: {e}")
traceback.print_exc()
return None
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
@@ -379,8 +202,6 @@ def from_file(ctx: ConversionContext, model_name):
# limitations under the License.
""" Conversion script for the LDM checkpoints. """
def get_db_models():
return []
def shave_segments(path, n_shave_prefix_segments=1):
"""
@@ -1075,7 +896,7 @@ def convert_open_clip_checkpoint(checkpoint):
if 'cond_stage_model.model.text_projection' in checkpoint:
d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
else:
print("No projection shape found, setting to 1024")
logger.debug("No projection shape found, setting to 1024")
d_model = 1024
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
@@ -1130,7 +951,7 @@ def replace_symlinks(path, base):
blob_path = None
if blob_path is None:
print("NO BLOB")
logger.debug("NO BLOB")
return
os.replace(blob_path, path)
elif os.path.isdir(path):
@@ -1140,7 +961,6 @@ def replace_symlinks(path, base):
def download_model(db_config: DreamboothConfig, token):
tmp_dir = os.path.join(db_config.model_dir, "src")
working_dir = db_config.pretrained_model_name_or_path
hub_url = db_config.src
if "http" in hub_url or "huggingface.co" in hub_url:
@@ -1155,7 +975,7 @@ def download_model(db_config: DreamboothConfig, token):
)
if repo_info.sha is None:
print("Unable to fetch repo?")
logger.warning("Unable to fetch repo?")
return None, None
siblings = repo_info.siblings
@@ -1208,10 +1028,10 @@ def download_model(db_config: DreamboothConfig, token):
if files_to_fetch and config_file:
files_to_fetch.append(config_file)
print(f"Fetching files: {files_to_fetch}")
logger.info(f"Fetching files: {files_to_fetch}")
if not len(files_to_fetch):
print("Nothing to fetch!")
logger.debug("Nothing to fetch!")
return None, None
@@ -1296,9 +1116,6 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
print("Could not find valid config. Returning default v1 config.")
return get_config_path(model_versions["v1"], train_types["default"], config_base_name, prediction_type="epsilon")
def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
@@ -1352,7 +1169,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
if db_config is not None:
original_config_file = config
if model_info is not None:
print("Got model info.")
logger.debug("Got model info.")
if ".ckpt" in model_info or ".safetensors" in model_info:
# Set this to false, because we have a checkpoint where we can *maybe* get a revision.
from_hub = False
@@ -1360,28 +1177,26 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
checkpoint_file = model_info
else:
msg = "Unable to fetch model from hub."
print(msg)
logger.warning(msg)
return "", "", 0, 0, "", "", "", "", image_size, "", msg
reset_safe = False
try:
checkpoint = None
map_location = torch.device("cpu")
# Try to determine if v1 or v2 model if we have a ckpt
if not from_hub:
print("Loading model from checkpoint.")
logger.info("Loading model from checkpoint.")
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
os.environ["SAFETENSORS_FAST_GPU"] = "1"
try:
print("Loading safetensors...")
logger.debug("Loading safetensors...")
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
except Exception as e:
checkpoint = torch.jit.load(checkpoint_file)
else:
print("Loading ckpt...")
logger.debug("Loading ckpt...")
checkpoint = torch.load(checkpoint_file, map_location=map_location)
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
@@ -1401,7 +1216,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
if not is_512:
# v2.1 needs to upcast attention
print("Setting upcast_attention")
logger.debug("Setting upcast_attention")
upcast_attention = True
v2 = True
else:
@@ -1410,15 +1225,15 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet")
try:
unet = UNet2DConditionModel.from_pretrained(unet_dir)
print("Loaded unet.")
logger.debug("Loaded unet.")
unet_dict = unet.state_dict()
key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
print("We got v2!")
logger.debug("UNet using v2 parameters.")
v2 = True
except:
print("Exception loading unet!")
logger.error("Exception loading unet!")
traceback.print_exc()
if v2 and not is_512:
@@ -1428,7 +1243,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
original_config_file = get_config_file(train_unfrozen, v2, prediction_type)
print(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
db_config.resolution = image_size
db_config.lifetime_revision = revision
db_config.epoch = epoch
@@ -1438,7 +1253,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
db_config.save()
return
print(f"{'v2' if v2 else 'v1'} model loaded.")
logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
# Use existing YAML if present
if checkpoint_file is not None:
@@ -1447,10 +1262,10 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
original_config_file = config_check
if original_config_file is None or not os.path.exists(original_config_file):
print("Unable to select a config file: %s" % (original_config_file))
logger.warning("Unable to select a config file: %s" % (original_config_file))
return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
print(f"Trying to load: {original_config_file}")
logger.debug(f"Trying to load: {original_config_file}")
original_config = OmegaConf.load(original_config_file)
num_train_timesteps = original_config.model.params.timesteps
@@ -1489,7 +1304,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
print("Converting unet...")
logger.info("Converting UNet...")
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention
@@ -1501,14 +1316,16 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
db_config.has_ema = has_ema
db_config.save()
unet.load_state_dict(converted_unet_checkpoint)
print("Converting vae...")
logger.info("Converting VAE...")
# Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
print("Converting text encoder...")
logger.info("Converting text encoder...")
# Convert the text model.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenOpenCLIPEmbedder":
@@ -1557,17 +1374,17 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
scheduler=scheduler)
except Exception as e:
print(f"Exception setting up output: {e}")
logger.error(f"Exception setting up output: {e}")
pipe = None
traceback.print_exc()
if pipe is None or db_config is None:
msg = "Pipeline or config is not set, unable to continue."
print(msg)
logger.error(msg)
return "", "", 0, 0, "", "", "", "", image_size, "", msg
else:
resolution = db_config.resolution
print("Saving diffusion model...")
logger.info("Saving diffusion model...")
pipe.save_pretrained(db_config.pretrained_model_name_or_path)
result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
model_dir = db_config.model_dir
@@ -1576,12 +1393,12 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
src = db_config.src
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
if original_config_file is not None and os.path.exists(original_config_file):
logger.warn("copying original config: %s -> %s", original_config_file, db_config.model_dir)
logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
shutil.copy(original_config_file, db_config.model_dir)
basename = os.path.basename(original_config_file)
new_ex_path = os.path.join(db_config.model_dir, basename)
new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
logger.warn("copying model config to new name: %s -> %s", new_ex_path, new_name)
logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
if os.path.exists(new_name):
os.remove(new_name)
os.rename(new_ex_path, new_name)
@@ -1601,27 +1418,35 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
os.makedirs(rem_dir)
print(result_status)
logger.info(result_status)
return
def convert_diffusion_original(ctx: ConversionContext, model_name: str, tensor_file: str, opset: int, half: bool):
model_path = os.path.join(ctx.model_path, model_name)
torch_name = model_name.replace("onnx", "torch")
torch_path = os.path.join(ctx.model_path, torch_name)
working_name = os.path.join(ctx.model_path, torch_name, "working")
logger.info("Converting original Diffusers checkpoint %s: %s -> %s", model_name, tensor_file, model_path)
def convert_diffusion_original(
ctx: ConversionContext,
model: ModelDict,
source: str,
):
name = model["name"]
source = source or model["source"]
if os.path.exists(model_path):
dest = os.path.join(ctx.model_path, name)
logger.info("Converting original Diffusers checkpoint %s: %s -> %s", name, source, dest)
if os.path.exists(dest):
logger.info("ONNX pipeline already exists, skipping.")
return
torch_name = name + "-torch"
torch_path = os.path.join(ctx.cache_path, torch_name)
working_name = os.path.join(ctx.cache_path, torch_name, "working")
if os.path.exists(torch_path):
logger.info("Torch pipeline already exists, reusing.")
else:
logger.info("Converting original Diffusers check to Torch model: %s -> %s", tensor_file, torch_path)
extract_checkpoint(ctx, torch_name, tensor_file, from_hub=False)
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
extract_checkpoint(ctx, torch_name, source, from_hub=False)
logger.info("Converted original Diffusers checkpoint to Torch model.")
convert_diffusion_stable(ctx, model_path, working_name, opset, half, None)
logger.info("ONNX pipeline saved to %s", model_name)
convert_diffusion_stable(ctx, model, working_name)
logger.info("ONNX pipeline saved to %s", name)

View File

@@ -1,19 +1,22 @@
from logging import getLogger
from os import mkdir, path
from pathlib import Path
from shutil import rmtree
from typing import Dict
import torch
from diffusers import (
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionPipeline,
)
from torch.onnx import export
from logging import getLogger
from shutil import rmtree
import torch
from os import path, mkdir
from pathlib import Path
from onnx import load, save_model
from torch.onnx import export
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from .utils import ConversionContext
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
logger = getLogger(__name__)
@@ -48,21 +51,21 @@ def onnx_export(
@torch.no_grad()
def convert_diffusion_stable(
ctx: ConversionContext,
name: str,
url: str,
opset: int,
half: bool,
token: str,
single_vae: bool = False,
model: Dict,
source: str,
):
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
dtype = torch.float16 if half else torch.float32
name = model.get("name")
source = source or model.get("source")
single_vae = model.get("single_vae")
dtype = torch.float16 if ctx.half else torch.float32
dest_path = path.join(ctx.model_path, name)
# diffusers go into a directory rather than .onnx file
logger.info("converting Stable Diffusion model %s: %s -> %s/", name, url, dest_path)
logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path)
if single_vae:
logger.info("converting model with single VAE")
@@ -71,13 +74,16 @@ def convert_diffusion_stable(
logger.info("ONNX model already exists, skipping.")
return
if half and ctx.training_device != "cuda":
if ctx.half and ctx.training_device != "cuda":
raise ValueError(
"Half precision model export is only supported on GPUs with CUDA"
)
pipeline = StableDiffusionPipeline.from_pretrained(
url, torch_dtype=dtype, use_auth_token=token
source,
torch_dtype=dtype,
use_auth_token=ctx.token,
# cache_dir=path.join(ctx.cache_path, name)
).to(ctx.training_device)
output_path = Path(dest_path)
@@ -94,14 +100,16 @@ def convert_diffusion_stable(
onnx_export(
pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)),
model_args=(
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)
),
output_path=output_path / "text_encoder" / "model.onnx",
ordered_input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
opset=opset,
opset=ctx.opset,
)
del pipeline.text_encoder
@@ -113,7 +121,9 @@ def convert_diffusion_stable(
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
else:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
unet_scale = torch.tensor(False).to(device=ctx.training_device, dtype=torch.bool)
unet_scale = torch.tensor(False).to(
device=ctx.training_device, dtype=torch.bool
)
unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size
@@ -139,7 +149,7 @@ def convert_diffusion_stable(
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
},
opset=opset,
opset=ctx.opset,
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
)
unet_model_path = str(unet_path.absolute().as_posix())
@@ -182,7 +192,7 @@ def convert_diffusion_stable(
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
opset=ctx.opset,
)
else:
# VAE ENCODER
@@ -207,7 +217,7 @@ def convert_diffusion_stable(
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
opset=ctx.opset,
)
# VAE DECODER
@@ -230,7 +240,7 @@ def convert_diffusion_stable(
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
opset=ctx.opset,
)
del pipeline.vae
@@ -261,7 +271,7 @@ def convert_diffusion_stable(
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
},
opset=opset,
opset=ctx.opset,
)
del pipeline.safety_checker
safety_checker = OnnxRuntimeModel.from_pretrained(
@@ -312,4 +322,3 @@ def convert_diffusion_stable(
)
logger.info("ONNX pipeline is loadable")

View File

@@ -1,32 +1,34 @@
import torch
from logging import getLogger
from os import path
from shutil import copyfile
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from torch.onnx import export
from os import path
from logging import getLogger
from basicsr.archs.rrdbnet_arch import RRDBNet
from .utils import ConversionContext
from .utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@torch.no_grad()
def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
dest_path = path.join(ctx.model_path, name + ".pth")
dest_onnx = path.join(ctx.model_path, name + ".onnx")
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
def convert_upscale_resrgan(
ctx: ConversionContext,
model: ModelDict,
source: str,
):
name = model.get("name")
source = source or model.get("source")
scale = model.get("scale")
if path.isfile(dest_onnx):
dest = path.join(ctx.model_path, name + ".onnx")
logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
if path.isfile(dest):
logger.info("ONNX model already exists, skipping.")
return
if not path.isfile(dest_path):
logger.info("PTH model not found, downloading...")
download_path = load_file_from_url(
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
)
copyfile(download_path, dest_path)
logger.info("loading and training model")
model = RRDBNet(
num_in_ch=3,
@@ -37,7 +39,7 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale:
scale=scale,
)
torch_model = torch.load(dest_path, map_location=ctx.map_location)
torch_model = torch.load(source, map_location=ctx.map_location)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"])
else:
@@ -54,15 +56,15 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale:
"output": {2: "width", 3: "height"},
}
logger.info("exporting ONNX model to %s", dest_onnx)
logger.info("exporting ONNX model to %s", dest)
export(
model,
rng,
dest_onnx,
dest,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=opset,
opset_version=ctx.opset,
export_params=True,
)
logger.info("Real ESRGAN exported to ONNX successfully.")

View File

@@ -1,7 +1,129 @@
import shutil
from functools import partial
from logging import getLogger
from os import path
from pathlib import Path
from typing import Dict, Union, List, Optional, Tuple
import requests
import torch
from tqdm.auto import tqdm
logger = getLogger(__name__)
ModelDict = Dict[str, Union[str, int]]
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
class ConversionContext:
def __init__(self, model_path: str, device: str) -> None:
def __init__(
self,
model_path: str,
device: str,
cache_path: Optional[str] = None,
half: Optional[bool] = False,
opset: Optional[int] = None,
token: Optional[str] = None,
) -> None:
self.model_path = model_path
self.cache_path = cache_path or path.join(model_path, ".cache")
self.training_device = device
self.map_location = torch.device(device)
self.half = half
self.opset = opset
self.token = token
def download_progress(urls: List[Tuple[str, str]]):
for url, dest in urls:
dest_path = Path(dest).expanduser().resolve()
dest_path.parent.mkdir(parents=True, exist_ok=True)
if dest_path.exists():
logger.info("Destination already exists: %s", dest_path)
return str(dest_path.absolute())
req = requests.get(url, stream=True, allow_redirects=True)
if req.status_code != 200:
req.raise_for_status()  # raises HTTPError for 4xx/5xx; the RuntimeError below covers other non-200 codes
raise RuntimeError(
"Request to %s failed with status code: %s" % (url, req.status_code)
)
total = int(req.headers.get("Content-Length", 0))
desc = "unknown" if total == 0 else ""
req.raw.read = partial(req.raw.read, decode_content=True)
with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
with dest_path.open("wb") as f:
shutil.copyfileobj(data, f)
return str(dest_path.absolute())
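download_progress streams each URL to its destination with a tqdm progress bar and returns the absolute path; note the early return when a destination already exists, which also means only the first pair of a multi-element list is effectively processed. Typical single-pair usage, as fetch_model calls it:

dest = download_progress([
    (
        "https://civitai.com/api/download/models/5796",
        "../models/.cache/diffusion-unstable-ink-dream-v6.safetensors",
    )
])
# dest is the absolute destination path; an existing file short-circuits the download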
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1
half = rest[1] if len(rest) > 1 else False
opset = rest[2] if len(rest) > 2 else None
return {
"name": name,
"source": source,
"half": half,
"opset": opset,
"scale": scale,
}
else:
return model
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
single_vae = rest[0] if len(rest) > 0 else False
half = rest[1] if len(rest) > 1 else False
opset = rest[2] if len(rest) > 2 else None
return {
"name": name,
"source": source,
"half": half,
"opset": opset,
"single_vae": single_vae,
}
else:
return model
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1
half = rest[1] if len(rest) > 1 else False
opset = rest[2] if len(rest) > 2 else None
return {
"name": name,
"source": source,
"half": half,
"opset": opset,
"scale": scale,
}
else:
return model
known_formats = ["onnx", "pth", "ckpt", "safetensors"]
def source_format(model: Dict) -> Optional[str]:
if "format" in model:
return model["format"]
if "source" in model:
# splitext returns (root, ext) and the extension keeps its leading dot
_, ext = path.splitext(model["source"])
ext = ext.lstrip(".")
if ext in known_formats:
return ext
return None
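Together these helpers let legacy tuples and the new dict entries share one code path; a quick illustration (the name and URL here are hypothetical):

tuple_to_upscaling(("upscaling-example", "https://example.com/RealESRGAN_x4plus.pth", 4))
# {"name": "upscaling-example", "source": "https://...", "half": False, "opset": None, "scale": 4}
source_format({"source": "../models/tensors/unstableinkdream_v6.safetensors"})
# "safetensors", inferred from the file extension
source_format({"source": "civitai://5796", "format": "safetensors"})
# "safetensors", the explicit format wins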

View File

@@ -649,10 +649,10 @@ def chain():
return error_reply("chain pipeline must have a body")
data = yaml.safe_load(body)
with open("./schema.yaml", "r") as f:
with open("./schemas/chain.yaml", "r") as f:
schema = yaml.safe_load(f.read())
logger.info("validating chain request: %s against %s", data, schema)
logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema)
# get defaults from the regular parameters

View File

@@ -1,5 +1,5 @@
[tool.black]
force-exclude = '''/(lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py'''
force-exclude = '''/(diffusion_original|lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py'''
[tool.isort]
profile = "black"

View File

@@ -1,4 +1,4 @@
$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml
$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/chain.yaml
$schema: https://json-schema.org/draft/2020-12/schema
$defs:

api/schemas/extras.yaml Normal file
View File

@@ -0,0 +1,65 @@
$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/extras.yaml
$schema: https://json-schema.org/draft/2020-12/schema
$defs:
legacy_tuple:
type: array
items:
oneOf:
- type: string
- type: number
base_model:
type: object
required: [name, source]
properties:
format:
type: string
enum: [onnx, pth, ckpt, safetensors]
half:
type: boolean
name:
type: string
opset:
type: number
source:
type: string
correction_model:
allOf:
- $ref: "#/$defs/base_model"
diffusion_model:
allOf:
- $ref: "#/$defs/base_model"
upscaling_model:
allOf:
- $ref: "#/$defs/base_model"
- type: object
required: [scale]
properties:
scale:
type: number
type: object
additionalProperties: False
properties:
diffusion:
type: array
items:
oneOf:
- $ref: "#/$defs/legacy_tuple"
- $ref: "#/$defs/diffusion_model"
correction:
type: array
items:
oneOf:
- $ref: "#/$defs/legacy_tuple"
- $ref: "#/$defs/correction_model"
upscaling:
type: array
items:
oneOf:
- $ref: "#/$defs/legacy_tuple"
- $ref: "#/$defs/upscaling_model"

View File

@@ -58,6 +58,7 @@
"rocm",
"RRDB",
"runwayml",
"safetensors",
"scandir",
"scipy",
"scrollback",