
feat(api): add a way to download models from civitai or other https sources (#117)

Sean Sube 2023-02-10 22:41:24 -06:00
parent b3e4076775
commit 9f202486c2
Signed by: ssube, GPG Key ID: 3EED7B957D362AF1
12 changed files with 489 additions and 366 deletions

View File

@@ -1,9 +1,23 @@
 {
   "diffusion": [
-    ["diffusion-knollingcase", "Aybeeceedee/knollingcase"],
-    ["diffusion-openjourney", "prompthero/openjourney"],
-    ["diffusion-stably-diffused-onnx-v2-6", "../models/tensors/stablydiffuseds_26.safetensors"],
-    ["diffusion-unstable-ink-dream-onnx-v6", "../models/tensors/unstableinkdream_v6.safetensors"]
+    {
+      "name": "diffusion-knollingcase",
+      "source": "Aybeeceedee/knollingcase"
+    },
+    {
+      "name": "diffusion-openjourney",
+      "source": "prompthero/openjourney"
+    },
+    {
+      "name": "diffusion-stablydiffused-aesthetic-v2-6",
+      "source": "civitai://6266?type=Pruned%20Model&format=SafeTensor",
+      "format": "safetensors"
+    },
+    {
+      "name": "diffusion-unstable-ink-dream-v6",
+      "source": "civitai://5796",
+      "format": "safetensors"
+    }
   ],
   "correction": [],
   "upscaling": []

View File

@@ -1,9 +1,3 @@
-from .correction_gfpgan import convert_correction_gfpgan
-from .diffusion_original import convert_diffusion_original
-from .diffusion_stable import convert_diffusion_stable
-from .upscale_resrgan import convert_upscale_resrgan
-from .utils import ConversionContext
-
 import warnings
 from argparse import ArgumentParser
 from json import loads
@@ -11,9 +5,17 @@ from logging import getLogger
 from os import environ, makedirs, path
 from sys import exit
 from typing import Dict, List, Optional, Tuple
+from yaml import safe_load
+from jsonschema import validate, ValidationError
 
 import torch
 
+from .correction_gfpgan import convert_correction_gfpgan
+from .diffusion_original import convert_diffusion_original
+from .diffusion_stable import convert_diffusion_stable
+from .upscale_resrgan import convert_upscale_resrgan
+from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
+
 # suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
 warnings.filterwarnings(
     "ignore", ".*The shape inference of prim::Constant type is missing.*"
@@ -29,20 +31,39 @@ Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
 logger = getLogger(__name__)
 
+model_sources: Dict[str, Tuple[str, str]] = {
+    "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
+}
+
+model_source_huggingface = "huggingface://"
+
 # recommended models
 base_models: Models = {
     "diffusion": [
         # v1.x
-        ("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
-        ("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
-        # v2.x
-        ("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
         (
-            "stable-diffusion-onnx-v2-inpainting",
-            "stabilityai/stable-diffusion-2-inpainting",
+            "stable-diffusion-onnx-v1-5",
+            model_source_huggingface + "runwayml/stable-diffusion-v1-5",
         ),
+        # (
+        #     "stable-diffusion-onnx-v1-inpainting",
+        #     model_source_huggingface + "runwayml/stable-diffusion-inpainting",
+        # ),
+        # v2.x
+        # (
+        #     "stable-diffusion-onnx-v2-1",
+        #     model_source_huggingface + "stabilityai/stable-diffusion-2-1",
+        # ),
+        # (
+        #     "stable-diffusion-onnx-v2-inpainting",
+        #     model_source_huggingface + "stabilityai/stable-diffusion-2-inpainting",
+        # ),
         # TODO: should have its own converter
-        ("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
+        (
+            "upscaling-stable-diffusion-x4",
+            model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
+            True,
+        ),
     ],
     "correction": [
         (
@@ -79,35 +100,86 @@ model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
 training_device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
-def load_models(args, ctx: ConversionContext, models: Models):
+def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str:
+    cache_name = path.join(ctx.cache_path, name)
+    if format is not None:
+        # add an extension if possible, some of the conversion code checks for it
+        cache_name = "%s.%s" % (cache_name, format)
+
+    for proto in model_sources:
+        api_name, api_root = model_sources.get(proto)
+        if source.startswith(proto):
+            api_source = api_root % (source.removeprefix(proto))
+            logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
+            return download_progress([(api_source, cache_name)])
+
+    if source.startswith(model_source_huggingface):
+        hub_source = source.removeprefix(model_source_huggingface)
+        logger.info("Downloading model from Huggingface Hub: %s", hub_source)
+        # from_pretrained has a bunch of useful logic that snapshot_download by itself does not
+        return hub_source
+    elif source.startswith("https://"):
+        logger.info("Downloading model from: %s", source)
+        return download_progress([(source, cache_name)])
+    elif source.startswith("http://"):
+        logger.warning("Downloading model from insecure source: %s", source)
+        return download_progress([(source, cache_name)])
+    elif source.startswith(path.sep) or source.startswith("."):
+        logger.info("Using local model: %s", source)
+        return source
+    else:
+        logger.info("Unknown model location, using path as provided: %s", source)
+        return source
+
+
+def convert_models(ctx: ConversionContext, args, models: Models):
     if args.diffusion:
-        for source in models.get("diffusion"):
-            name, file = source
+        for model in models.get("diffusion"):
+            model = tuple_to_diffusion(model)
+            name = model.get("name")
+
             if name in args.skip:
-                logger.info("Skipping model: %s", source[0])
+                logger.info("Skipping model: %s", name)
             else:
-                if file.endswith(".safetensors") or file.endswith(".ckpt"):
-                    convert_diffusion_original(ctx, *source, args.opset, args.half)
+                format = source_format(model)
+                source = fetch_model(ctx, name, model["source"], format=format)
+
+                if format in ["safetensors", "ckpt"]:
+                    convert_diffusion_original(
+                        ctx,
+                        model,
+                        source,
+                    )
                 else:
-                    # TODO: make this a parameter in the JSON/dict
-                    single_vae = "upscaling" in source[0]
                     convert_diffusion_stable(
-                        ctx, *source, args.opset, args.half, args.token, single_vae=single_vae
+                        ctx,
+                        model,
+                        source,
                     )
 
     if args.upscaling:
-        for source in models.get("upscaling"):
-            if source[0] in args.skip:
-                logger.info("Skipping model: %s", source[0])
+        for model in models.get("upscaling"):
+            model = tuple_to_upscaling(model)
+            name = model.get("name")
+
+            if name in args.skip:
+                logger.info("Skipping model: %s", name)
             else:
-                convert_upscale_resrgan(ctx, *source, args.opset)
+                format = source_format(model)
+                source = fetch_model(ctx, name, model["source"], format=format)
+                convert_upscale_resrgan(ctx, model, source)
 
     if args.correction:
-        for source in models.get("correction"):
-            if source[0] in args.skip:
-                logger.info("Skipping model: %s", source[0])
+        for model in models.get("correction"):
+            model = tuple_to_correction(model)
+            name = model.get("name")
+
+            if name in args.skip:
+                logger.info("Skipping model: %s", name)
             else:
-                convert_correction_gfpgan(ctx, *source, args.opset)
+                format = source_format(model)
+                source = fetch_model(ctx, name, model["source"], format=format)
+                convert_correction_gfpgan(ctx, model, source)
 
 
 def main() -> int:
@@ -146,7 +218,7 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)
 
-    ctx = ConversionContext(model_path, training_device)
+    ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
 
     if not path.exists(model_path):
@@ -154,16 +226,26 @@ def main() -> int:
         makedirs(model_path)
 
     logger.info("Converting base models.")
-    load_models(args, ctx, base_models)
+    convert_models(ctx, args, base_models)
 
     for file in args.extras:
        if file is not None and file != "":
            logger.info("Loading extra models from %s", file)
            try:
                with open(file, "r") as f:
-                    data = loads(f.read())
+                    data = safe_load(f.read())
+
+                with open("./schemas/extras.yaml", "r") as f:
+                    schema = safe_load(f.read())
+
+                logger.debug("validating chain request: %s against %s", data, schema)
+
+                try:
+                    validate(data, schema)
                     logger.info("Converting extra models.")
-                    load_models(args, ctx, data)
+                    convert_models(ctx, args, data)
+                except ValidationError as err:
+                    logger.error("Invalid data in extras file: %s", err)
            except Exception as err:
                logger.error("Error converting extra models: %s", err)

View File

@@ -1,31 +1,34 @@
-import torch
+from logging import getLogger
+from os import path
 from shutil import copyfile
+
+import torch
+from basicsr.archs.rrdbnet_arch import RRDBNet
 from basicsr.utils.download_util import load_file_from_url
 from torch.onnx import export
-from os import path
-from logging import getLogger
-from basicsr.archs.rrdbnet_arch import RRDBNet
 
-from .utils import ConversionContext
+from .utils import ConversionContext, ModelDict
 
 logger = getLogger(__name__)
 
 
-@torch.no_grad()
-def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
-    dest_path = path.join(ctx.model_path, name + ".pth")
-    dest_onnx = path.join(ctx.model_path, name + ".onnx")
-    logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
+@torch.no_grad()
+def convert_correction_gfpgan(
+    ctx: ConversionContext,
+    model: ModelDict,
+    source: str,
+):
+    name = model.get("name")
+    source = source or model.get("source")
+    scale = model.get("scale")
+
+    dest = path.join(ctx.model_path, name + ".onnx")
+    logger.info("converting GFPGAN model: %s -> %s", name, dest)
 
-    if path.isfile(dest_onnx):
+    if path.isfile(dest):
         logger.info("ONNX model already exists, skipping.")
         return
 
-    if not path.isfile(dest_path):
-        logger.info("PTH model not found, downloading...")
-        download_path = load_file_from_url(
-            url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
-        )
-        copyfile(download_path, dest_path)
-
     logger.info("loading and training model")
     model = RRDBNet(
         num_in_ch=3,
@@ -36,7 +39,7 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale
         scale=scale,
     )
 
-    torch_model = torch.load(dest_path, map_location=ctx.map_location)
+    torch_model = torch.load(source, map_location=ctx.map_location)
     # TODO: make sure strict=False is safe here
     if "params_ema" in torch_model:
         model.load_state_dict(torch_model["params_ema"], strict=False)
@@ -54,15 +57,15 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale
         "output": {2: "width", 3: "height"},
     }
 
-    logger.info("exporting ONNX model to %s", dest_onnx)
+    logger.info("exporting ONNX model to %s", dest)
     export(
         model,
         rng,
-        dest_onnx,
+        dest,
         input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
-        opset_version=opset,
+        opset_version=ctx.opset,
         export_params=True,
     )
     logger.info("GFPGAN exported to ONNX successfully.")

View File

@@ -11,6 +11,17 @@
 # TODO: ask about license before merging
 ###
 
+import json
+import os
+import re
+import shutil
+import traceback
+from logging import getLogger
+from typing import Dict, List
+
+import huggingface_hub.utils.tqdm
+import safetensors.torch
+import torch
+
 from diffusers import (
     AutoencoderKL,
     DDIMScheduler,
@@ -20,95 +31,34 @@ from diffusers import (
     HeunDiscreteScheduler,
     LDMTextToImagePipeline,
     LMSDiscreteScheduler,
+    PaintByExamplePipeline,
     PNDMScheduler,
     StableDiffusionPipeline,
-    UNet2DConditionModel, PaintByExamplePipeline,
+    UNet2DConditionModel,
 )
-from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
+    LDMBertConfig,
+    LDMBertModel,
+)
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfApi, hf_hub_download
-from logging import getLogger
 from omegaconf import OmegaConf
 from pydantic import BaseModel
-from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
-from typing import Dict, List, Union
-
-import huggingface_hub.utils.tqdm
-import json
-import os
-import re
-import safetensors.torch
-import shutil
-import torch
-import traceback
+from transformers import (
+    AutoFeatureExtractor,
+    BertTokenizerFast,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionConfig,
+)
 
 from .diffusion_stable import convert_diffusion_stable
-from .utils import ConversionContext
+from .utils import ConversionContext, ModelDict
 
 logger = getLogger(__name__)
 
 
-def get_images():
-    return []
-
-
-class Concept(BaseModel):
-    class_data_dir: str = ""
-    class_guidance_scale: float = 7.5
-    class_infer_steps: int = 60
-    class_negative_prompt: str = ""
-    class_prompt: str = ""
-    class_token: str = ""
-    instance_data_dir: str = ""
-    instance_prompt: str = ""
-    instance_token: str = ""
-    is_valid: bool = False
-    n_save_sample: int = 1
-    num_class_images: int = 0
-    num_class_images_per: int = 0
-    sample_seed: int = -1
-    save_guidance_scale: float = 7.5
-    save_infer_steps: int = 60
-    save_sample_negative_prompt: str = ""
-    save_sample_prompt: str = ""
-    save_sample_template: str = ""
-
-    def __init__(self, input_dict: Union[Dict, None] = None, **kwargs):
-        super().__init__(**kwargs)
-        if input_dict is not None:
-            self.load_params(input_dict)
-        if self.is_valid and self.num_class_images != 0:
-            if self.num_class_images_per == 0:
-                images = get_images(self.instance_data_dir)
-                if len(images) < self.num_class_images * 2:
-                    self.num_class_images_per = 1
-                else:
-                    self.num_class_images_per = self.num_class_images // len(images)
-                self.num_class_images = 0
-
-    def to_dict(self):
-        return self.dict()
-
-    def to_json(self):
-        return json.dumps(self.to_dict())
-
-    def load_params(self, params_dict):
-        for key, value in params_dict.items():
-            if hasattr(self, key):
-                setattr(self, key, value)
-        if self.instance_data_dir:
-            self.is_valid = os.path.isdir(self.instance_data_dir)
-        else:
-            self.is_valid = False
-
-
-# Keys to save, replacing our dumb __init__ method
-save_keys = []
-
-# Keys to return to the ui when Load Settings is clicked.
-ui_keys = []
-
 
 def sanitize_name(name):
     return "".join(x for x in name if (x.isalnum() or x in "._- "))
@@ -200,7 +150,7 @@ class DreamboothConfig(BaseModel):
         super().__init__(**kwargs)
 
         model_name = sanitize_name(model_name)
-        model_dir = os.path.join(ctx.model_path, model_name)
+        model_dir = os.path.join(ctx.cache_path, model_name)
         working_dir = os.path.join(model_dir, "working")
 
         if not os.path.exists(working_dir):
@@ -214,7 +164,6 @@ class DreamboothConfig(BaseModel):
         self.scheduler = scheduler
         self.v2 = v2
 
-    # Actually save as a file
     def save(self, backup=False):
         """
         Save the config file
@@ -236,132 +185,6 @@ class DreamboothConfig(BaseModel):
             if hasattr(self, key):
                 setattr(self, key, value)
 
-    # Pass a dict and return a list of Concept objects
-    def concepts(self, required: int = -1):
-        concepts = []
-        c_idx = 0
-        # If using a file for concepts and not requesting from UI, load from file
-        if self.use_concepts and self.concepts_path and required == -1:
-            concepts_list = concepts_from_file(self.concepts_path)
-        # Otherwise, use 'stored' list
-        else:
-            concepts_list = self.concepts_list
-        if required == -1:
-            required = len(concepts_list)
-
-        for concept_dict in concepts_list:
-            concept = Concept(input_dict=concept_dict)
-            if concept.is_valid:
-                if concept.class_data_dir == "" or concept.class_data_dir is None:
-                    concept.class_data_dir = os.path.join(self.model_dir, f"classifiers_{c_idx}")
-            concepts.append(concept)
-            c_idx += 1
-
-        missing = len(concepts) - required
-        if missing > 0:
-            concepts.extend([Concept(None)] * missing)
-        return concepts
-
-    # Set default values
-    def check_defaults(self):
-        if self.model_name is not None and self.model_name != "":
-            if self.revision == "" or self.revision is None:
-                self.revision = 0
-            if self.epoch == "" or self.epoch is None:
-                self.epoch = 0
-            self.model_name = "".join(x for x in self.model_name if (x.isalnum() or x in "._- "))
-            models_path = "."  # TODO: use ctx path
-            model_dir = os.path.join(models_path, self.model_name)
-            working_dir = os.path.join(model_dir, "working")
-            if not os.path.exists(working_dir):
-                os.makedirs(working_dir)
-            self.model_dir = model_dir
-            self.pretrained_model_name_or_path = working_dir
-
-
-def concepts_from_file(concepts_path: str):
-    concepts = []
-    if os.path.exists(concepts_path) and os.path.isfile(str):
-        try:
-            with open(concepts_path, "r") as concepts_file:
-                concepts_str = concepts_file.read()
-        except Exception as e:
-            print(f"Exception opening concepts file: {e}")
-    else:
-        concepts_str = concepts_path
-
-    try:
-        concepts_data = json.loads(concepts_str)
-        for concept_data in concepts_data:
-            concept = Concept(input_dict=concept_data)
-            if concept.is_valid:
-                concepts.append(concept.__dict__)
-    except Exception as e:
-        print(f"Exception parsing concepts: {e}")
-    return concepts
-
-
-def save_config(*args):
-    raise Exception("where tho")
-    params = list(args)
-    concept_keys = ["c1_", "c2_", "c3_", "c4_"]
-    model_name = params[0]
-    if model_name is None or model_name == "":
-        print("Invalid model name.")
-        return
-    config = from_file(ctx, model_name)
-    if config is None:
-        config = DreamboothConfig(model_name)
-    params_dict = dict(zip(save_keys, params))
-    concepts_list = []
-    # If using a concepts file/string, keep concepts_list empty.
-    if params_dict["db_use_concepts"] and params_dict["db_concepts_path"]:
-        concepts_list = []
-        params_dict["concepts_list"] = concepts_list
-    else:
-        for concept_key in concept_keys:
-            concept_dict = {}
-            for key, param in params_dict.items():
-                if concept_key in key and param is not None:
-                    concept_dict[key.replace(concept_key, "")] = param
-            concept_test = Concept(concept_dict)
-            if concept_test.is_valid:
-                concepts_list.append(concept_test.__dict__)
-        existing_concepts = params_dict["concepts_list"] if "concepts_list" in params_dict else []
-        if len(concepts_list) and not len(existing_concepts):
-            params_dict["concepts_list"] = concepts_list
-
-    config.load_params(params_dict)
-    config.save()
-
-
-def from_file(ctx: ConversionContext, model_name):
-    """
-    Load config data from UI
-    Args:
-        model_name: The config to load
-
-    Returns: Dict | None
-    """
-    if model_name == "" or model_name is None:
-        return None
-
-    model_name = sanitize_name(model_name)
-    config_file = os.path.join(ctx.model_path, model_name, "db_config.json")
-    try:
-        with open(config_file, 'r') as openfile:
-            config_dict = json.load(openfile)
-
-        config = DreamboothConfig(model_name)
-        config.load_params(config_dict)
-        return config
-    except Exception as e:
-        print(f"Exception loading config: {e}")
-        traceback.print_exc()
-        return None
-
 
 # coding=utf-8
 # Copyright 2022 The HuggingFace Inc. team.
@@ -379,8 +202,6 @@ def from_file(ctx: ConversionContext, model_name):
 # limitations under the License.
 """ Conversion script for the LDM checkpoints. """
 
-def get_db_models():
-    return []
-
 
 def shave_segments(path, n_shave_prefix_segments=1):
     """
@@ -1075,7 +896,7 @@ def convert_open_clip_checkpoint(checkpoint):
     if 'cond_stage_model.model.text_projection' in checkpoint:
         d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
     else:
-        print("No projection shape found, setting to 1024")
+        logger.debug("No projection shape found, setting to 1024")
         d_model = 1024
 
     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
@@ -1130,7 +951,7 @@ def replace_symlinks(path, base):
             blob_path = None
 
         if blob_path is None:
-            print("NO BLOB")
+            logger.debug("NO BLOB")
             return
         os.replace(blob_path, path)
     elif os.path.isdir(path):
@@ -1140,7 +961,6 @@ def replace_symlinks(path, base):
 
 def download_model(db_config: DreamboothConfig, token):
     tmp_dir = os.path.join(db_config.model_dir, "src")
-    working_dir = db_config.pretrained_model_name_or_path
 
     hub_url = db_config.src
     if "http" in hub_url or "huggingface.co" in hub_url:
@@ -1155,7 +975,7 @@ def download_model(db_config: DreamboothConfig, token):
     )
 
     if repo_info.sha is None:
-        print("Unable to fetch repo?")
+        logger.warning("Unable to fetch repo?")
         return None, None
 
     siblings = repo_info.siblings
@@ -1208,10 +1028,10 @@ def download_model(db_config: DreamboothConfig, token):
     if files_to_fetch and config_file:
         files_to_fetch.append(config_file)
 
-    print(f"Fetching files: {files_to_fetch}")
+    logger.info(f"Fetching files: {files_to_fetch}")
 
     if not len(files_to_fetch):
-        print("Nothing to fetch!")
+        logger.debug("Nothing to fetch!")
        return None, None
 
@@ -1296,9 +1116,6 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
             return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
 
-    print("Could not find valid config. Returning default v1 config.")
-    return get_config_path(model_versions["v1"], train_types["default"], config_base_name, prediction_type="epsilon")
-
 
 def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
                        new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
@@ -1352,7 +1169,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
     if db_config is not None:
         original_config_file = config
     if model_info is not None:
-        print("Got model info.")
+        logger.debug("Got model info.")
         if ".ckpt" in model_info or ".safetensors" in model_info:
             # Set this to false, because we have a checkpoint where we can *maybe* get a revision.
             from_hub = False
@@ -1360,28 +1177,26 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             checkpoint_file = model_info
         else:
             msg = "Unable to fetch model from hub."
-            print(msg)
+            logger.warning(msg)
             return "", "", 0, 0, "", "", "", "", image_size, "", msg
 
-    reset_safe = False
-
     try:
         checkpoint = None
         map_location = torch.device("cpu")
 
         # Try to determine if v1 or v2 model if we have a ckpt
         if not from_hub:
-            print("Loading model from checkpoint.")
+            logger.info("Loading model from checkpoint.")
             _, extension = os.path.splitext(checkpoint_file)
             if extension.lower() == ".safetensors":
                 os.environ["SAFETENSORS_FAST_GPU"] = "1"
                 try:
-                    print("Loading safetensors...")
+                    logger.debug("Loading safetensors...")
                     checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
                 except Exception as e:
                     checkpoint = torch.jit.load(checkpoint_file)
             else:
-                print("Loading ckpt...")
+                logger.debug("Loading ckpt...")
                 checkpoint = torch.load(checkpoint_file, map_location=map_location)
                 checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
@@ -1401,7 +1216,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
                 if not is_512:
                     # v2.1 needs to upcast attention
-                    print("Setting upcast_attention")
+                    logger.debug("Setting upcast_attention")
                     upcast_attention = True
                     v2 = True
             else:
@@ -1410,15 +1225,15 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet")
             try:
                 unet = UNet2DConditionModel.from_pretrained(unet_dir)
-                print("Loaded unet.")
+                logger.debug("Loaded unet.")
                 unet_dict = unet.state_dict()
                 key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"
                 if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
-                    print("We got v2!")
+                    logger.debug("UNet using v2 parameters.")
                     v2 = True
             except:
-                print("Exception loading unet!")
+                logger.error("Exception loading unet!")
                 traceback.print_exc()
 
         if v2 and not is_512:
@@ -1428,7 +1243,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         original_config_file = get_config_file(train_unfrozen, v2, prediction_type)
 
-    print(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
+    logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
     db_config.resolution = image_size
     db_config.lifetime_revision = revision
     db_config.epoch = epoch
@@ -1438,7 +1253,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         db_config.save()
         return
 
-    print(f"{'v2' if v2 else 'v1'} model loaded.")
+    logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
 
     # Use existing YAML if present
     if checkpoint_file is not None:
@@ -1447,10 +1262,10 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             original_config_file = config_check
 
     if original_config_file is None or not os.path.exists(original_config_file):
-        print("Unable to select a config file: %s" % (original_config_file))
+        logger.warning("Unable to select a config file: %s" % (original_config_file))
         return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
 
-    print(f"Trying to load: {original_config_file}")
+    logger.debug(f"Trying to load: {original_config_file}")
     original_config = OmegaConf.load(original_config_file)
 
     num_train_timesteps = original_config.model.params.timesteps
@@ -1489,7 +1304,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
 
-    print("Converting unet...")
+    logger.info("Converting UNet...")
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
@@ -1501,14 +1316,16 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         db_config.has_ema = has_ema
         db_config.save()
     unet.load_state_dict(converted_unet_checkpoint)
-    print("Converting vae...")
+
+    logger.info("Converting VAE...")
     # Convert the VAE model.
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
-    print("Converting text encoder...")
+
+    logger.info("Converting text encoder...")
     # Convert the text model.
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
     if text_model_type == "FrozenOpenCLIPEmbedder":
@@ -1557,17 +1374,17 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
                                           scheduler=scheduler)
     except Exception as e:
-        print(f"Exception setting up output: {e}")
+        logger.error(f"Exception setting up output: {e}")
         pipe = None
         traceback.print_exc()
 
     if pipe is None or db_config is None:
         msg = "Pipeline or config is not set, unable to continue."
-        print(msg)
+        logger.error(msg)
         return "", "", 0, 0, "", "", "", "", image_size, "", msg
     else:
         resolution = db_config.resolution
-        print("Saving diffusion model...")
+        logger.info("Saving diffusion model...")
         pipe.save_pretrained(db_config.pretrained_model_name_or_path)
         result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
         model_dir = db_config.model_dir
@@ -1576,12 +1393,12 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         src = db_config.src
         required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
         if original_config_file is not None and os.path.exists(original_config_file):
-            logger.warn("copying original config: %s -> %s", original_config_file, db_config.model_dir)
+            logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
             shutil.copy(original_config_file, db_config.model_dir)
             basename = os.path.basename(original_config_file)
             new_ex_path = os.path.join(db_config.model_dir, basename)
             new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
-            logger.warn("copying model config to new name: %s -> %s", new_ex_path, new_name)
+            logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
             if os.path.exists(new_name):
                 os.remove(new_name)
             os.rename(new_ex_path, new_name)
@@ -1601,27 +1418,35 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             os.makedirs(rem_dir)
 
-    print(result_status)
+    logger.info(result_status)
 
     return
 
 
-def convert_diffusion_original(ctx: ConversionContext, model_name: str, tensor_file: str, opset: int, half: bool):
-    model_path = os.path.join(ctx.model_path, model_name)
-    torch_name = model_name.replace("onnx", "torch")
-    torch_path = os.path.join(ctx.model_path, torch_name)
-    working_name = os.path.join(ctx.model_path, torch_name, "working")
-    logger.info("Converting original Diffusers checkpoint %s: %s -> %s", model_name, tensor_file, model_path)
-
-    if os.path.exists(model_path):
+def convert_diffusion_original(
+    ctx: ConversionContext,
+    model: ModelDict,
+    source: str,
+):
+    name = model["name"]
+    source = source or model["source"]
+    dest = os.path.join(ctx.model_path, name)
+    logger.info("Converting original Diffusers checkpoint %s: %s -> %s", name, source, dest)
+
+    if os.path.exists(dest):
         logger.info("ONNX pipeline already exists, skipping.")
         return
 
+    torch_name = name + "-torch"
+    torch_path = os.path.join(ctx.cache_path, torch_name)
+    working_name = os.path.join(ctx.cache_path, torch_name, "working")
+
     if os.path.exists(torch_path):
         logger.info("Torch pipeline already exists, reusing.")
     else:
-        logger.info("Converting original Diffusers check to Torch model: %s -> %s", tensor_file, torch_path)
-        extract_checkpoint(ctx, torch_name, tensor_file, from_hub=False)
+        logger.info("Converting original Diffusers checkpoint to Torch model: %s -> %s", source, torch_path)
+        extract_checkpoint(ctx, torch_name, source, from_hub=False)
         logger.info("Converted original Diffusers checkpoint to Torch model.")
 
-    convert_diffusion_stable(ctx, model_path, working_name, opset, half, None)
-    logger.info("ONNX pipeline saved to %s", model_name)
+    convert_diffusion_stable(ctx, model, working_name)
+    logger.info("ONNX pipeline saved to %s", name)

View File

@@ -1,19 +1,22 @@
+from logging import getLogger
+from os import mkdir, path
+from pathlib import Path
+from shutil import rmtree
+from typing import Dict
+
+import torch
 from diffusers import (
     OnnxRuntimeModel,
     OnnxStableDiffusionPipeline,
     StableDiffusionPipeline,
 )
-from torch.onnx import export
-from logging import getLogger
-from shutil import rmtree
-import torch
-from os import path, mkdir
-from pathlib import Path
 from onnx import load, save_model
+from torch.onnx import export
 
+from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
+    OnnxStableDiffusionUpscalePipeline,
+)
 from .utils import ConversionContext
-from ..diffusion.pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
 
 logger = getLogger(__name__)
@@ -48,21 +51,21 @@ def onnx_export(
 @torch.no_grad()
 def convert_diffusion_stable(
     ctx: ConversionContext,
-    name: str,
-    url: str,
-    opset: int,
-    half: bool,
-    token: str,
-    single_vae: bool = False,
+    model: Dict,
+    source: str,
 ):
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
-    dtype = torch.float16 if half else torch.float32
+    name = model.get("name")
+    source = source or model.get("source")
+    single_vae = model.get("single_vae")
+
+    dtype = torch.float16 if ctx.half else torch.float32
     dest_path = path.join(ctx.model_path, name)
 
     # diffusers go into a directory rather than .onnx file
-    logger.info("converting Stable Diffusion model %s: %s -> %s/", name, url, dest_path)
+    logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path)
 
     if single_vae:
         logger.info("converting model with single VAE")
@@ -71,13 +74,16 @@ def convert_diffusion_stable(
         logger.info("ONNX model already exists, skipping.")
         return
 
-    if half and ctx.training_device != "cuda":
+    if ctx.half and ctx.training_device != "cuda":
         raise ValueError(
             "Half precision model export is only supported on GPUs with CUDA"
         )
 
     pipeline = StableDiffusionPipeline.from_pretrained(
-        url, torch_dtype=dtype, use_auth_token=token
+        source,
+        torch_dtype=dtype,
+        use_auth_token=ctx.token,
+        # cache_dir=path.join(ctx.cache_path, name)
     ).to(ctx.training_device)
     output_path = Path(dest_path)
@@ -94,14 +100,16 @@ def convert_diffusion_stable(
     onnx_export(
         pipeline.text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
-        model_args=(text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)),
+        model_args=(
+            text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)
+        ),
         output_path=output_path / "text_encoder" / "model.onnx",
         ordered_input_names=["input_ids"],
         output_names=["last_hidden_state", "pooler_output"],
         dynamic_axes={
             "input_ids": {0: "batch", 1: "sequence"},
         },
-        opset=opset,
+        opset=ctx.opset,
     )
     del pipeline.text_encoder
@@ -113,7 +121,9 @@ def convert_diffusion_stable(
         unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
     else:
         unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
-        unet_scale = torch.tensor(False).to(device=ctx.training_device, dtype=torch.bool)
+        unet_scale = torch.tensor(False).to(
+            device=ctx.training_device, dtype=torch.bool
+        )
 
     unet_in_channels = pipeline.unet.config.in_channels
     unet_sample_size = pipeline.unet.config.sample_size
@@ -139,7 +149,7 @@ def convert_diffusion_stable(
             "timestep": {0: "batch"},
             "encoder_hidden_states": {0: "batch", 1: "sequence"},
         },
-        opset=opset,
+        opset=ctx.opset,
         use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
     )
     unet_model_path = str(unet_path.absolute().as_posix())
@@ -182,7 +192,7 @@ def convert_diffusion_stable(
             dynamic_axes={
                 "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
             },
-            opset=opset,
+            opset=ctx.opset,
         )
     else:
         # VAE ENCODER
@@ -207,7 +217,7 @@ def convert_diffusion_stable(
             dynamic_axes={
                 "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
             },
-            opset=opset,
+            opset=ctx.opset,
         )
 
         # VAE DECODER
@@ -230,7 +240,7 @@ def convert_diffusion_stable(
             dynamic_axes={
                 "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
             },
-            opset=opset,
+            opset=ctx.opset,
         )
     del pipeline.vae
@@ -261,7 +271,7 @@ def convert_diffusion_stable(
                 "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
                 "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
             },
-            opset=opset,
+            opset=ctx.opset,
         )
         del pipeline.safety_checker
         safety_checker = OnnxRuntimeModel.from_pretrained(
@@ -312,4 +322,3 @@ def convert_diffusion_stable(
     )
 
     logger.info("ONNX pipeline is loadable")
-

View File

@@ -1,32 +1,34 @@
-import torch
+from logging import getLogger
+from os import path
 from shutil import copyfile
+
+import torch
+from basicsr.archs.rrdbnet_arch import RRDBNet
 from basicsr.utils.download_util import load_file_from_url
 from torch.onnx import export
-from os import path
-from logging import getLogger
-from basicsr.archs.rrdbnet_arch import RRDBNet
 
-from .utils import ConversionContext
+from .utils import ConversionContext, ModelDict
 
 logger = getLogger(__name__)
 
 
 @torch.no_grad()
-def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
-    dest_path = path.join(ctx.model_path, name + ".pth")
-    dest_onnx = path.join(ctx.model_path, name + ".onnx")
-    logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
+def convert_upscale_resrgan(
+    ctx: ConversionContext,
+    model: ModelDict,
+    source: str,
+):
+    name = model.get("name")
+    source = source or model.get("source")
+    scale = model.get("scale")
 
-    if path.isfile(dest_onnx):
+    dest = path.join(ctx.model_path, name + ".onnx")
+    logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
+
+    if path.isfile(dest):
         logger.info("ONNX model already exists, skipping.")
         return
 
-    if not path.isfile(dest_path):
-        logger.info("PTH model not found, downloading...")
-        download_path = load_file_from_url(
-            url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
-        )
-        copyfile(download_path, dest_path)
-
     logger.info("loading and training model")
     model = RRDBNet(
         num_in_ch=3,
@@ -37,7 +39,7 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale:
         scale=scale,
     )
 
-    torch_model = torch.load(dest_path, map_location=ctx.map_location)
+    torch_model = torch.load(source, map_location=ctx.map_location)
     if "params_ema" in torch_model:
         model.load_state_dict(torch_model["params_ema"])
     else:
@@ -54,15 +56,15 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale:
         "output": {2: "width", 3: "height"},
     }
 
-    logger.info("exporting ONNX model to %s", dest_onnx)
+    logger.info("exporting ONNX model to %s", dest)
     export(
         model,
         rng,
-        dest_onnx,
+        dest,
         input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
-        opset_version=opset,
+        opset_version=ctx.opset,
        export_params=True,
    )
    logger.info("Real ESRGAN exported to ONNX successfully.")

View File

@@ -1,7 +1,129 @@
+import shutil
+from functools import partial
+from logging import getLogger
+from os import path
+from pathlib import Path
+from typing import Dict, Union, List, Optional, Tuple
+
+import requests
 import torch
+from tqdm.auto import tqdm
+
+logger = getLogger(__name__)
+
+ModelDict = Dict[str, Union[str, int]]
+LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
 
 
 class ConversionContext:
-    def __init__(self, model_path: str, device: str) -> None:
+    def __init__(
+        self,
+        model_path: str,
+        device: str,
+        cache_path: Optional[str] = None,
+        half: Optional[bool] = False,
+        opset: Optional[int] = None,
+        token: Optional[str] = None,
+    ) -> None:
         self.model_path = model_path
+        self.cache_path = cache_path or path.join(model_path, ".cache")
         self.training_device = device
         self.map_location = torch.device(device)
+        self.half = half
+        self.opset = opset
+        self.token = token
+
+
+def download_progress(urls: List[Tuple[str, str]]):
+    for url, dest in urls:
+        dest_path = Path(dest).expanduser().resolve()
+        dest_path.parent.mkdir(parents=True, exist_ok=True)
+
+        if dest_path.exists():
+            logger.info("Destination already exists: %s", dest_path)
+            return str(dest_path.absolute())
+
+        req = requests.get(url, stream=True, allow_redirects=True)
+        if req.status_code != 200:
+            req.raise_for_status()  # Only works for 4xx errors, per SO answer
+            raise RuntimeError(
+                "Request to %s failed with status code: %s" % (url, req.status_code)
+            )
+
+        total = int(req.headers.get("Content-Length", 0))
+        desc = "unknown" if total == 0 else ""
+        req.raw.read = partial(req.raw.read, decode_content=True)
+        with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
+            with dest_path.open("wb") as f:
+                shutil.copyfileobj(data, f)
+
+        return str(dest_path.absolute())
+
+
+def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
+    if isinstance(model, list) or isinstance(model, tuple):
+        name, source, *rest = model
+        scale = rest[0] if len(rest) > 0 else 1
+        half = rest[1] if len(rest) > 1 else False
+        opset = rest[2] if len(rest) > 2 else None
+        return {
+            "name": name,
+            "source": source,
+            "half": half,
+            "opset": opset,
+            "scale": scale,
+        }
+    else:
+        return model
+
+
+def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
+    if isinstance(model, list) or isinstance(model, tuple):
+        name, source, *rest = model
+        single_vae = rest[0] if len(rest) > 0 else False
+        half = rest[1] if len(rest) > 1 else False
+        opset = rest[2] if len(rest) > 2 else None
+        return {
+            "name": name,
+            "source": source,
+            "half": half,
+            "opset": opset,
+            "single_vae": single_vae,
+        }
+    else:
+        return model
+
+
+def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
+    if isinstance(model, list) or isinstance(model, tuple):
+        name, source, *rest = model
+        scale = rest[0] if len(rest) > 0 else 1
+        half = rest[1] if len(rest) > 1 else False
+        opset = rest[2] if len(rest) > 2 else None
+        return {
+            "name": name,
+            "source": source,
+            "half": half,
+            "opset": opset,
+            "scale": scale,
+        }
+    else:
+        return model
+
+
+known_formats = ["onnx", "pth", "ckpt", "safetensors"]
+
+
+def source_format(model: Dict) -> Optional[str]:
+    if "format" in model:
+        return model["format"]
+
+    if "source" in model:
+        # splitext returns (root, ext) and the extension keeps its leading dot
+        _, ext = path.splitext(model["source"])
+        if ext.lstrip(".") in known_formats:
+            return ext.lstrip(".")
+
+    return None
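Quick sanity examples for source_format, assuming the helper above is in scope; the model values are illustrative:

print(source_format({"name": "a", "source": "../models/x.safetensors"}))
# safetensors, guessed from the extension
print(source_format({"name": "b", "source": "civitai://5796", "format": "ckpt"}))
# ckpt, the explicit format wins
print(source_format({"name": "c", "source": "civitai://5796"}))
# None, nothing to go on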

View File

@@ -649,10 +649,10 @@ def chain():
         return error_reply("chain pipeline must have a body")
 
     data = yaml.safe_load(body)
-    with open("./schema.yaml", "r") as f:
+    with open("./schemas/chain.yaml", "r") as f:
         schema = yaml.safe_load(f.read())
 
-    logger.info("validating chain request: %s against %s", data, schema)
+    logger.debug("validating chain request: %s against %s", data, schema)
     validate(data, schema)
 
     # get defaults from the regular parameters

View File

@ -1,5 +1,5 @@
[tool.black] [tool.black]
force-exclude = '''/(lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py''' force-exclude = '''/(diffusion_original|lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py'''
[tool.isort] [tool.isort]
profile = "black" profile = "black"

View File

@@ -1,4 +1,4 @@
-$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml
+$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/chain.yaml
 $schema: https://json-schema.org/draft/2020-12/schema
 
 $defs:

api/schemas/extras.yaml Normal file

View File

@@ -0,0 +1,65 @@
+$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/extras.yaml
+$schema: https://json-schema.org/draft/2020-12/schema
+
+$defs:
+  legacy_tuple:
+    type: array
+    items:
+      oneOf:
+        - type: string
+        - type: number
+
+  base_model:
+    type: object
+    required: [name, source]
+    properties:
+      format:
+        type: string
+        enum: [onnx, pth, ckpt, safetensors]
+      half:
+        type: boolean
+      name:
+        type: string
+      opset:
+        type: number
+      source:
+        type: string
+
+  correction_model:
+    allOf:
+      - $ref: "#/$defs/base_model"
+
+  diffusion_model:
+    allOf:
+      - $ref: "#/$defs/base_model"
+
+  upscaling_model:
+    allOf:
+      - $ref: "#/$defs/base_model"
+      - type: object
+        required: [scale]
+        properties:
+          scale:
+            type: number
+
+type: object
+additionalProperties: False
+properties:
+  diffusion:
+    type: array
+    items:
+      oneOf:
+        - $ref: "#/$defs/legacy_tuple"
+        - $ref: "#/$defs/diffusion_model"
+  correction:
+    type: array
+    items:
+      oneOf:
+        - $ref: "#/$defs/legacy_tuple"
+        - $ref: "#/$defs/correction_model"
+  upscaling:
+    type: array
+    items:
+      oneOf:
+        - $ref: "#/$defs/legacy_tuple"
+        - $ref: "#/$defs/upscaling_model"

View File

@@ -58,6 +58,7 @@
     "rocm",
     "RRDB",
     "runwayml",
+    "safetensors",
     "scandir",
     "scipy",
     "scrollback",