feat(api): add a way to download models from civitai or other https sources (#117)
This commit is contained in:
parent
b3e4076775
commit
9f202486c2
|
@ -1,9 +1,23 @@
|
|||
{
|
||||
"diffusion": [
|
||||
["diffusion-knollingcase", "Aybeeceedee/knollingcase"],
|
||||
["diffusion-openjourney", "prompthero/openjourney"],
|
||||
["diffusion-stably-diffused-onnx-v2-6", "../models/tensors/stablydiffuseds_26.safetensors"],
|
||||
["diffusion-unstable-ink-dream-onnx-v6", "../models/tensors/unstableinkdream_v6.safetensors"]
|
||||
{
|
||||
"name": "diffusion-knollingcase",
|
||||
"source": "Aybeeceedee/knollingcase"
|
||||
},
|
||||
{
|
||||
"name": "diffusion-openjourney",
|
||||
"source": "prompthero/openjourney"
|
||||
},
|
||||
{
|
||||
"name": "diffusion-stablydiffused-aesthetic-v2-6",
|
||||
"source": "civitai://6266?type=Pruned%20Model&format=SafeTensor",
|
||||
"format": "safetensors"
|
||||
},
|
||||
{
|
||||
"name": "diffusion-unstable-ink-dream-v6",
|
||||
"source": "civitai://5796",
|
||||
"format": "safetensors"
|
||||
}
|
||||
],
|
||||
"correction": [],
|
||||
"upscaling": []
|
||||
|
|
|
@ -1,9 +1,3 @@
|
|||
from .correction_gfpgan import convert_correction_gfpgan
|
||||
from .diffusion_original import convert_diffusion_original
|
||||
from .diffusion_stable import convert_diffusion_stable
|
||||
from .upscale_resrgan import convert_upscale_resrgan
|
||||
from .utils import ConversionContext
|
||||
|
||||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
from json import loads
|
||||
|
@ -11,9 +5,17 @@ from logging import getLogger
|
|||
from os import environ, makedirs, path
|
||||
from sys import exit
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from yaml import safe_load
|
||||
from jsonschema import validate, ValidationError
|
||||
|
||||
import torch
|
||||
|
||||
from .correction_gfpgan import convert_correction_gfpgan
|
||||
from .diffusion_original import convert_diffusion_original
|
||||
from .diffusion_stable import convert_diffusion_stable
|
||||
from .upscale_resrgan import convert_upscale_resrgan
|
||||
from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
|
||||
|
||||
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
||||
warnings.filterwarnings(
|
||||
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
||||
|
@ -29,20 +31,39 @@ Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
model_sources: Dict[str, Tuple[str, str]] = {
|
||||
"civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
|
||||
}
|
||||
|
||||
model_source_huggingface = "huggingface://"
|
||||
|
||||
# recommended models
|
||||
base_models: Models = {
|
||||
"diffusion": [
|
||||
# v1.x
|
||||
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
|
||||
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
|
||||
# v2.x
|
||||
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
|
||||
(
|
||||
"stable-diffusion-onnx-v2-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
"stable-diffusion-onnx-v1-5",
|
||||
model_source_huggingface + "runwayml/stable-diffusion-v1-5",
|
||||
),
|
||||
# (
|
||||
# "stable-diffusion-onnx-v1-inpainting",
|
||||
# model_source_huggingface + "runwayml/stable-diffusion-inpainting",
|
||||
# ),
|
||||
# v2.x
|
||||
# (
|
||||
# "stable-diffusion-onnx-v2-1",
|
||||
# model_source_huggingface + "stabilityai/stable-diffusion-2-1",
|
||||
# ),
|
||||
# (
|
||||
# "stable-diffusion-onnx-v2-inpainting",
|
||||
# model_source_huggingface + "stabilityai/stable-diffusion-2-inpainting",
|
||||
# ),
|
||||
# TODO: should have its own converter
|
||||
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
|
||||
(
|
||||
"upscaling-stable-diffusion-x4",
|
||||
model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
|
||||
True,
|
||||
),
|
||||
],
|
||||
"correction": [
|
||||
(
|
||||
|
@ -79,35 +100,86 @@ model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
|||
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def load_models(args, ctx: ConversionContext, models: Models):
|
||||
def fetch_model(
    ctx: ConversionContext,
    name: str,
    source: str,
    format: Optional[str] = None,
) -> str:
    """
    Resolve a model source into something the converters can load.

    Args:
        ctx: conversion context; its cache_path is where downloads are stored.
        name: model name, used as the file name inside the cache.
        source: where the model comes from. Recognized forms are a prefix
            registered in model_sources, the Huggingface Hub prefix, an
            https:// or http:// URL, or a local path.
        format: optional file extension appended to the cache file name.

    Returns:
        For downloads, the result of download_progress (presumably the path of
        the cached file - confirm against its definition). For Huggingface Hub
        sources, the source with the prefix removed. Otherwise the source as
        provided.
    """
    cache_name = path.join(ctx.cache_path, name)
    if format is not None:
        # some of the conversion code checks the extension, so add one when it is known
        cache_name = "%s.%s" % (cache_name, format)

    # sources served by a registered download API
    for proto, (api_name, api_root) in model_sources.items():
        if not source.startswith(proto):
            continue

        api_source = api_root % (source.removeprefix(proto))
        logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
        return download_progress([(api_source, cache_name)])

    if source.startswith(model_source_huggingface):
        hub_source = source.removeprefix(model_source_huggingface)
        logger.info("Downloading model from Huggingface Hub: %s", hub_source)
        # from_pretrained has a bunch of useful logic that snapshot_download by itself does not,
        # so only the hub name is returned here and nothing is downloaded yet
        return hub_source

    if source.startswith("https://"):
        logger.info("Downloading model from: %s", source)
        return download_progress([(source, cache_name)])

    if source.startswith("http://"):
        logger.warning("Downloading model from insecure source: %s", source)
        return download_progress([(source, cache_name)])

    if source.startswith((path.sep, ".")):
        logger.info("Using local model: %s", source)
        return source

    logger.info("Unknown model location, using path as provided: %s", source)
    return source
|
||||
|
||||
|
||||
def convert_models(ctx: ConversionContext, args, models: Models):
|
||||
if args.diffusion:
|
||||
for source in models.get("diffusion"):
|
||||
name, file = source
|
||||
for model in models.get("diffusion"):
|
||||
model = tuple_to_diffusion(model)
|
||||
name = model.get("name")
|
||||
|
||||
if name in args.skip:
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
logger.info("Skipping model: %s", name)
|
||||
else:
|
||||
if file.endswith(".safetensors") or file.endswith(".ckpt"):
|
||||
convert_diffusion_original(ctx, *source, args.opset, args.half)
|
||||
format = source_format(model)
|
||||
source = fetch_model(ctx, name, model["source"], format=format)
|
||||
|
||||
if format in ["safetensors", "ckpt"]:
|
||||
convert_diffusion_original(
|
||||
ctx,
|
||||
model,
|
||||
source,
|
||||
)
|
||||
else:
|
||||
# TODO: make this a parameter in the JSON/dict
|
||||
single_vae = "upscaling" in source[0]
|
||||
convert_diffusion_stable(
|
||||
ctx, *source, args.opset, args.half, args.token, single_vae=single_vae
|
||||
ctx,
|
||||
model,
|
||||
source,
|
||||
)
|
||||
|
||||
if args.upscaling:
|
||||
for source in models.get("upscaling"):
|
||||
if source[0] in args.skip:
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
for model in models.get("upscaling"):
|
||||
model = tuple_to_upscaling(model)
|
||||
name = model.get("name")
|
||||
|
||||
if name in args.skip:
|
||||
logger.info("Skipping model: %s", name)
|
||||
else:
|
||||
convert_upscale_resrgan(ctx, *source, args.opset)
|
||||
format = source_format(model)
|
||||
source = fetch_model(ctx, name, model["source"], format=format)
|
||||
convert_upscale_resrgan(ctx, model, source)
|
||||
|
||||
if args.correction:
|
||||
for source in models.get("correction"):
|
||||
if source[0] in args.skip:
|
||||
logger.info("Skipping model: %s", source[0])
|
||||
for model in models.get("correction"):
|
||||
model = tuple_to_correction(model)
|
||||
name = model.get("name")
|
||||
|
||||
if name in args.skip:
|
||||
logger.info("Skipping model: %s", name)
|
||||
else:
|
||||
convert_correction_gfpgan(ctx, *source, args.opset)
|
||||
format = source_format(model)
|
||||
source = fetch_model(ctx, name, model["source"], format=format)
|
||||
convert_correction_gfpgan(ctx, model, source)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
|
@ -146,7 +218,7 @@ def main() -> int:
|
|||
args = parser.parse_args()
|
||||
logger.info("CLI arguments: %s", args)
|
||||
|
||||
ctx = ConversionContext(model_path, training_device)
|
||||
ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
|
||||
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
||||
|
||||
if not path.exists(model_path):
|
||||
|
@ -154,16 +226,26 @@ def main() -> int:
|
|||
makedirs(model_path)
|
||||
|
||||
logger.info("Converting base models.")
|
||||
load_models(args, ctx, base_models)
|
||||
convert_models(ctx, args, base_models)
|
||||
|
||||
for file in args.extras:
|
||||
if file is not None and file != "":
|
||||
logger.info("Loading extra models from %s", file)
|
||||
try:
|
||||
with open(file, "r") as f:
|
||||
data = loads(f.read())
|
||||
data = safe_load(f.read())
|
||||
|
||||
with open("./schemas/extras.yaml", "r") as f:
|
||||
schema = safe_load(f.read())
|
||||
|
||||
logger.debug("validating chain request: %s against %s", data, schema)
|
||||
|
||||
try:
|
||||
validate(data, schema)
|
||||
logger.info("Converting extra models.")
|
||||
load_models(args, ctx, data)
|
||||
convert_models(ctx, args, data)
|
||||
except ValidationError as err:
|
||||
logger.error("Invalid data in extras file: %s", err)
|
||||
except Exception as err:
|
||||
logger.error("Error converting extra models: %s", err)
|
||||
|
||||
|
|
|
@ -1,31 +1,34 @@
|
|||
import torch
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from shutil import copyfile
|
||||
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.onnx import export
|
||||
from os import path
|
||||
from logging import getLogger
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from .utils import ConversionContext
|
||||
|
||||
from .utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(ctx.model_path, name + ".pth")
|
||||
dest_onnx = path.join(ctx.model_path, name + ".onnx")
|
||||
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
@torch.no_grad()
|
||||
def convert_correction_gfpgan(
|
||||
ctx: ConversionContext,
|
||||
model: ModelDict,
|
||||
source: str,
|
||||
):
|
||||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
scale = model.get("scale")
|
||||
|
||||
dest = path.join(ctx.model_path, name + ".onnx")
|
||||
logger.info("converting GFPGAN model: %s -> %s", name, dest)
|
||||
|
||||
if path.isfile(dest):
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
logger.info("PTH model not found, downloading...")
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||
)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
logger.info("loading and training model")
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
|
@ -36,7 +39,7 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale
|
|||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=ctx.map_location)
|
||||
torch_model = torch.load(source, map_location=ctx.map_location)
|
||||
# TODO: make sure strict=False is safe here
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"], strict=False)
|
||||
|
@ -54,15 +57,15 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale
|
|||
"output": {2: "width", 3: "height"},
|
||||
}
|
||||
|
||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||
logger.info("exporting ONNX model to %s", dest)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
dest_onnx,
|
||||
dest,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=opset,
|
||||
opset_version=ctx.opset,
|
||||
export_params=True,
|
||||
)
|
||||
logger.info("GFPGAN exported to ONNX successfully.")
|
||||
|
|
|
@ -11,6 +11,17 @@
|
|||
# TODO: ask about license before merging
|
||||
###
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import traceback
|
||||
from logging import getLogger
|
||||
from typing import Dict, List
|
||||
|
||||
import huggingface_hub.utils.tqdm
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
|
@ -20,95 +31,34 @@ from diffusers import (
|
|||
HeunDiscreteScheduler,
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PaintByExamplePipeline,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel, PaintByExamplePipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
||||
LDMBertConfig,
|
||||
LDMBertModel,
|
||||
)
|
||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from logging import getLogger
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import BaseModel
|
||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import huggingface_hub.utils.tqdm
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import safetensors.torch
|
||||
import shutil
|
||||
import torch
|
||||
import traceback
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
BertTokenizerFast,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionConfig,
|
||||
)
|
||||
|
||||
from .diffusion_stable import convert_diffusion_stable
|
||||
from .utils import ConversionContext
|
||||
from .utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_images(image_dir=None):
    """
    Placeholder image lookup that always reports no images.

    Args:
        image_dir: directory to look in. Accepted so that callers which pass a
            directory (Concept passes its instance_data_dir) do not fail with
            a TypeError; the value is currently ignored.

    Returns:
        An empty list.
    """
    return []
|
||||
|
||||
class Concept(BaseModel):
|
||||
class_data_dir: str = ""
|
||||
class_guidance_scale: float = 7.5
|
||||
class_infer_steps: int = 60
|
||||
class_negative_prompt: str = ""
|
||||
class_prompt: str = ""
|
||||
class_token: str = ""
|
||||
instance_data_dir: str = ""
|
||||
instance_prompt: str = ""
|
||||
instance_token: str = ""
|
||||
is_valid: bool = False
|
||||
n_save_sample: int = 1
|
||||
num_class_images: int = 0
|
||||
num_class_images_per: int = 0
|
||||
sample_seed: int = -1
|
||||
save_guidance_scale: float = 7.5
|
||||
save_infer_steps: int = 60
|
||||
save_sample_negative_prompt: str = ""
|
||||
save_sample_prompt: str = ""
|
||||
save_sample_template: str = ""
|
||||
|
||||
def __init__(self, input_dict: Union[Dict, None] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if input_dict is not None:
|
||||
self.load_params(input_dict)
|
||||
if self.is_valid and self.num_class_images != 0:
|
||||
if self.num_class_images_per == 0:
|
||||
images = get_images(self.instance_data_dir)
|
||||
if len(images) < self.num_class_images * 2:
|
||||
self.num_class_images_per = 1
|
||||
else:
|
||||
self.num_class_images_per = self.num_class_images // len(images)
|
||||
self.num_class_images = 0
|
||||
|
||||
def to_dict(self):
|
||||
return self.dict()
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
def load_params(self, params_dict):
|
||||
for key, value in params_dict.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
if self.instance_data_dir:
|
||||
self.is_valid = os.path.isdir(self.instance_data_dir)
|
||||
else:
|
||||
self.is_valid = False
|
||||
|
||||
|
||||
# Keys to save, replacing our dumb __init__ method
|
||||
save_keys = []
|
||||
|
||||
# Keys to return to the ui when Load Settings is clicked.
|
||||
ui_keys = []
|
||||
|
||||
|
||||
def sanitize_name(name):
    """
    Strip every character that is not alphanumeric or one of ". _ -" and space.

    Args:
        name: the raw name to clean up.

    Returns:
        The name with all other characters removed; the order of the kept
        characters is unchanged.
    """
    allowed_punctuation = "._- "
    kept = [char for char in name if char.isalnum() or char in allowed_punctuation]
    return "".join(kept)
|
||||
|
||||
|
@ -200,7 +150,7 @@ class DreamboothConfig(BaseModel):
|
|||
|
||||
super().__init__(**kwargs)
|
||||
model_name = sanitize_name(model_name)
|
||||
model_dir = os.path.join(ctx.model_path, model_name)
|
||||
model_dir = os.path.join(ctx.cache_path, model_name)
|
||||
working_dir = os.path.join(model_dir, "working")
|
||||
|
||||
if not os.path.exists(working_dir):
|
||||
|
@ -214,7 +164,6 @@ class DreamboothConfig(BaseModel):
|
|||
self.scheduler = scheduler
|
||||
self.v2 = v2
|
||||
|
||||
# Actually save as a file
|
||||
def save(self, backup=False):
|
||||
"""
|
||||
Save the config file
|
||||
|
@ -236,132 +185,6 @@ class DreamboothConfig(BaseModel):
|
|||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
# Pass a dict and return a list of Concept objects
|
||||
def concepts(self, required: int = -1):
|
||||
concepts = []
|
||||
c_idx = 0
|
||||
# If using a file for concepts and not requesting from UI, load from file
|
||||
if self.use_concepts and self.concepts_path and required == -1:
|
||||
concepts_list = concepts_from_file(self.concepts_path)
|
||||
|
||||
# Otherwise, use 'stored' list
|
||||
else:
|
||||
concepts_list = self.concepts_list
|
||||
if required == -1:
|
||||
required = len(concepts_list)
|
||||
|
||||
for concept_dict in concepts_list:
|
||||
concept = Concept(input_dict=concept_dict)
|
||||
if concept.is_valid:
|
||||
if concept.class_data_dir == "" or concept.class_data_dir is None:
|
||||
concept.class_data_dir = os.path.join(self.model_dir, f"classifiers_{c_idx}")
|
||||
concepts.append(concept)
|
||||
c_idx += 1
|
||||
|
||||
missing = len(concepts) - required
|
||||
if missing > 0:
|
||||
concepts.extend([Concept(None)] * missing)
|
||||
return concepts
|
||||
|
||||
# Set default values
|
||||
def check_defaults(self):
|
||||
if self.model_name is not None and self.model_name != "":
|
||||
if self.revision == "" or self.revision is None:
|
||||
self.revision = 0
|
||||
if self.epoch == "" or self.epoch is None:
|
||||
self.epoch = 0
|
||||
self.model_name = "".join(x for x in self.model_name if (x.isalnum() or x in "._- "))
|
||||
models_path = "." # TODO: use ctx path
|
||||
model_dir = os.path.join(models_path, self.model_name)
|
||||
working_dir = os.path.join(model_dir, "working")
|
||||
if not os.path.exists(working_dir):
|
||||
os.makedirs(working_dir)
|
||||
self.model_dir = model_dir
|
||||
self.pretrained_model_name_or_path = working_dir
|
||||
|
||||
|
||||
def concepts_from_file(concepts_path: str):
    """
    Load concept definitions from a JSON file or from an inline JSON string.

    Args:
        concepts_path: path to a JSON file. When it does not name an existing
            file, the value itself is parsed as JSON.

    Returns:
        A list holding the attribute dict of every entry that produced a valid
        Concept. The list is empty when the input cannot be read or parsed.
    """
    concepts = []

    # test the path itself; isfile already implies existence, so one check covers both
    if os.path.isfile(concepts_path):
        try:
            with open(concepts_path, "r") as concepts_file:
                concepts_str = concepts_file.read()
        except Exception as e:
            print(f"Exception opening concepts file: {e}")
            # nothing was read, so there is nothing to parse
            return concepts
    else:
        concepts_str = concepts_path

    try:
        concepts_data = json.loads(concepts_str)
        for concept_data in concepts_data:
            concept = Concept(input_dict=concept_data)
            if concept.is_valid:
                concepts.append(concept.__dict__)
    except Exception as e:
        # best effort: report the problem and return whatever was collected so far
        print(f"Exception parsing concepts: {e}")

    return concepts
|
||||
|
||||
|
||||
def save_config(*args):
|
||||
raise Exception("where tho")
|
||||
params = list(args)
|
||||
concept_keys = ["c1_", "c2_", "c3_", "c4_"]
|
||||
model_name = params[0]
|
||||
if model_name is None or model_name == "":
|
||||
print("Invalid model name.")
|
||||
return
|
||||
config = from_file(ctx, model_name)
|
||||
if config is None:
|
||||
config = DreamboothConfig(model_name)
|
||||
params_dict = dict(zip(save_keys, params))
|
||||
concepts_list = []
|
||||
# If using a concepts file/string, keep concepts_list empty.
|
||||
if params_dict["db_use_concepts"] and params_dict["db_concepts_path"]:
|
||||
concepts_list = []
|
||||
params_dict["concepts_list"] = concepts_list
|
||||
else:
|
||||
for concept_key in concept_keys:
|
||||
concept_dict = {}
|
||||
for key, param in params_dict.items():
|
||||
if concept_key in key and param is not None:
|
||||
concept_dict[key.replace(concept_key, "")] = param
|
||||
concept_test = Concept(concept_dict)
|
||||
if concept_test.is_valid:
|
||||
concepts_list.append(concept_test.__dict__)
|
||||
existing_concepts = params_dict["concepts_list"] if "concepts_list" in params_dict else []
|
||||
if len(concepts_list) and not len(existing_concepts):
|
||||
params_dict["concepts_list"] = concepts_list
|
||||
|
||||
config.load_params(params_dict)
|
||||
config.save()
|
||||
|
||||
|
||||
def from_file(ctx: ConversionContext, model_name):
    """
    Load a saved config from the db_config.json file of a model.

    Args:
        ctx: conversion context; its model_path is the directory searched.
        model_name: name of the config to load; sanitized before it is used
            in the file path.

    Returns:
        The loaded DreamboothConfig, or None when the name is empty or the
        file cannot be opened, parsed or applied.
    """
    if model_name is None or model_name == "":
        return None

    safe_name = sanitize_name(model_name)
    config_path = os.path.join(ctx.model_path, safe_name, "db_config.json")

    try:
        with open(config_path, 'r') as handle:
            saved_params = json.load(handle)

        loaded = DreamboothConfig(safe_name)
        loaded.load_params(saved_params)
    except Exception as e:
        # any failure is reported and treated as "no config"
        print(f"Exception loading config: {e}")
        traceback.print_exc()
        return None

    return loaded
|
||||
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
|
@ -379,8 +202,6 @@ def from_file(ctx: ConversionContext, model_name):
|
|||
# limitations under the License.
|
||||
""" Conversion script for the LDM checkpoints. """
|
||||
|
||||
def get_db_models():
    """
    Stub that reports no models.

    Returns:
        A new empty list on every call.
    """
    return []
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
"""
|
||||
|
@ -1075,7 +896,7 @@ def convert_open_clip_checkpoint(checkpoint):
|
|||
if 'cond_stage_model.model.text_projection' in checkpoint:
|
||||
d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
|
||||
else:
|
||||
print("No projection shape found, setting to 1024")
|
||||
logger.debug("No projection shape found, setting to 1024")
|
||||
d_model = 1024
|
||||
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
||||
|
||||
|
@ -1130,7 +951,7 @@ def replace_symlinks(path, base):
|
|||
blob_path = None
|
||||
|
||||
if blob_path is None:
|
||||
print("NO BLOB")
|
||||
logger.debug("NO BLOB")
|
||||
return
|
||||
os.replace(blob_path, path)
|
||||
elif os.path.isdir(path):
|
||||
|
@ -1140,7 +961,6 @@ def replace_symlinks(path, base):
|
|||
|
||||
def download_model(db_config: DreamboothConfig, token):
|
||||
tmp_dir = os.path.join(db_config.model_dir, "src")
|
||||
working_dir = db_config.pretrained_model_name_or_path
|
||||
|
||||
hub_url = db_config.src
|
||||
if "http" in hub_url or "huggingface.co" in hub_url:
|
||||
|
@ -1155,7 +975,7 @@ def download_model(db_config: DreamboothConfig, token):
|
|||
)
|
||||
|
||||
if repo_info.sha is None:
|
||||
print("Unable to fetch repo?")
|
||||
logger.warning("Unable to fetch repo?")
|
||||
return None, None
|
||||
|
||||
siblings = repo_info.siblings
|
||||
|
@ -1208,10 +1028,10 @@ def download_model(db_config: DreamboothConfig, token):
|
|||
if files_to_fetch and config_file:
|
||||
files_to_fetch.append(config_file)
|
||||
|
||||
print(f"Fetching files: {files_to_fetch}")
|
||||
logger.info(f"Fetching files: {files_to_fetch}")
|
||||
|
||||
if not len(files_to_fetch):
|
||||
print("Nothing to fetch!")
|
||||
logger.debug("Nothing to fetch!")
|
||||
return None, None
|
||||
|
||||
|
||||
|
@ -1296,9 +1116,6 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
|
|||
|
||||
return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
|
||||
|
||||
print("Could not find valid config. Returning default v1 config.")
|
||||
return get_config_path(model_versions["v1"], train_types["default"], config_base_name, prediction_type="epsilon")
|
||||
|
||||
|
||||
def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
|
||||
new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
|
||||
|
@ -1352,7 +1169,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
if db_config is not None:
|
||||
original_config_file = config
|
||||
if model_info is not None:
|
||||
print("Got model info.")
|
||||
logger.debug("Got model info.")
|
||||
if ".ckpt" in model_info or ".safetensors" in model_info:
|
||||
# Set this to false, because we have a checkpoint where we can *maybe* get a revision.
|
||||
from_hub = False
|
||||
|
@ -1360,28 +1177,26 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
checkpoint_file = model_info
|
||||
else:
|
||||
msg = "Unable to fetch model from hub."
|
||||
print(msg)
|
||||
logger.warning(msg)
|
||||
return "", "", 0, 0, "", "", "", "", image_size, "", msg
|
||||
|
||||
reset_safe = False
|
||||
|
||||
try:
|
||||
checkpoint = None
|
||||
map_location = torch.device("cpu")
|
||||
|
||||
# Try to determine if v1 or v2 model if we have a ckpt
|
||||
if not from_hub:
|
||||
print("Loading model from checkpoint.")
|
||||
logger.info("Loading model from checkpoint.")
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
os.environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||
try:
|
||||
print("Loading safetensors...")
|
||||
logger.debug("Loading safetensors...")
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
except Exception as e:
|
||||
checkpoint = torch.jit.load(checkpoint_file)
|
||||
else:
|
||||
print("Loading ckpt...")
|
||||
logger.debug("Loading ckpt...")
|
||||
checkpoint = torch.load(checkpoint_file, map_location=map_location)
|
||||
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||
|
||||
|
@ -1401,7 +1216,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
if not is_512:
|
||||
# v2.1 needs to upcast attention
|
||||
print("Setting upcast_attention")
|
||||
logger.debug("Setting upcast_attention")
|
||||
upcast_attention = True
|
||||
v2 = True
|
||||
else:
|
||||
|
@ -1410,15 +1225,15 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet")
|
||||
try:
|
||||
unet = UNet2DConditionModel.from_pretrained(unet_dir)
|
||||
print("Loaded unet.")
|
||||
logger.debug("Loaded unet.")
|
||||
unet_dict = unet.state_dict()
|
||||
key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
|
||||
print("We got v2!")
|
||||
logger.debug("UNet using v2 parameters.")
|
||||
v2 = True
|
||||
|
||||
except:
|
||||
print("Exception loading unet!")
|
||||
logger.error("Exception loading unet!")
|
||||
traceback.print_exc()
|
||||
|
||||
if v2 and not is_512:
|
||||
|
@ -1428,7 +1243,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
|
||||
original_config_file = get_config_file(train_unfrozen, v2, prediction_type)
|
||||
|
||||
print(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
|
||||
logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
|
||||
db_config.resolution = image_size
|
||||
db_config.lifetime_revision = revision
|
||||
db_config.epoch = epoch
|
||||
|
@ -1438,7 +1253,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
db_config.save()
|
||||
return
|
||||
|
||||
print(f"{'v2' if v2 else 'v1'} model loaded.")
|
||||
logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
|
||||
|
||||
# Use existing YAML if present
|
||||
if checkpoint_file is not None:
|
||||
|
@ -1447,10 +1262,10 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
original_config_file = config_check
|
||||
|
||||
if original_config_file is None or not os.path.exists(original_config_file):
|
||||
print("Unable to select a config file: %s" % (original_config_file))
|
||||
logger.warning("Unable to select a config file: %s" % (original_config_file))
|
||||
return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
|
||||
|
||||
print(f"Trying to load: {original_config_file}")
|
||||
logger.debug(f"Trying to load: {original_config_file}")
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
|
@ -1489,7 +1304,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
|
||||
|
||||
print("Converting unet...")
|
||||
logger.info("Converting UNet...")
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
|
@ -1501,14 +1316,16 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
db_config.has_ema = has_ema
|
||||
db_config.save()
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
print("Converting vae...")
|
||||
|
||||
logger.info("Converting VAE...")
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
print("Converting text encoder...")
|
||||
|
||||
logger.info("Converting text encoder...")
|
||||
# Convert the text model.
|
||||
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
if text_model_type == "FrozenOpenCLIPEmbedder":
|
||||
|
@ -1557,17 +1374,17 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
|
||||
scheduler=scheduler)
|
||||
except Exception as e:
|
||||
print(f"Exception setting up output: {e}")
|
||||
logger.error(f"Exception setting up output: {e}")
|
||||
pipe = None
|
||||
traceback.print_exc()
|
||||
|
||||
if pipe is None or db_config is None:
|
||||
msg = "Pipeline or config is not set, unable to continue."
|
||||
print(msg)
|
||||
logger.error(msg)
|
||||
return "", "", 0, 0, "", "", "", "", image_size, "", msg
|
||||
else:
|
||||
resolution = db_config.resolution
|
||||
print("Saving diffusion model...")
|
||||
logger.info("Saving diffusion model...")
|
||||
pipe.save_pretrained(db_config.pretrained_model_name_or_path)
|
||||
result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
|
||||
model_dir = db_config.model_dir
|
||||
|
@ -1576,12 +1393,12 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
src = db_config.src
|
||||
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
|
||||
if original_config_file is not None and os.path.exists(original_config_file):
|
||||
logger.warn("copying original config: %s -> %s", original_config_file, db_config.model_dir)
|
||||
logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
|
||||
shutil.copy(original_config_file, db_config.model_dir)
|
||||
basename = os.path.basename(original_config_file)
|
||||
new_ex_path = os.path.join(db_config.model_dir, basename)
|
||||
new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
|
||||
logger.warn("copying model config to new name: %s -> %s", new_ex_path, new_name)
|
||||
logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
|
||||
if os.path.exists(new_name):
|
||||
os.remove(new_name)
|
||||
os.rename(new_ex_path, new_name)
|
||||
|
@ -1601,27 +1418,35 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
os.makedirs(rem_dir)
|
||||
|
||||
|
||||
print(result_status)
|
||||
logger.info(result_status)
|
||||
|
||||
return
|
||||
|
||||
def convert_diffusion_original(ctx: ConversionContext, model_name: str, tensor_file: str, opset: int, half: bool):
|
||||
model_path = os.path.join(ctx.model_path, model_name)
|
||||
torch_name = model_name.replace("onnx", "torch")
|
||||
torch_path = os.path.join(ctx.model_path, torch_name)
|
||||
working_name = os.path.join(ctx.model_path, torch_name, "working")
|
||||
logger.info("Converting original Diffusers checkpoint %s: %s -> %s", model_name, tensor_file, model_path)
|
||||
def convert_diffusion_original(
|
||||
ctx: ConversionContext,
|
||||
model: ModelDict,
|
||||
source: str,
|
||||
):
|
||||
name = model["name"]
|
||||
source = source or model["source"]
|
||||
|
||||
if os.path.exists(model_path):
|
||||
dest = os.path.join(ctx.model_path, name)
|
||||
logger.info("Converting original Diffusers checkpoint %s: %s -> %s", name, source, dest)
|
||||
|
||||
if os.path.exists(dest):
|
||||
logger.info("ONNX pipeline already exists, skipping.")
|
||||
return
|
||||
|
||||
torch_name = name + "-torch"
|
||||
torch_path = os.path.join(ctx.cache_path, torch_name)
|
||||
working_name = os.path.join(ctx.cache_path, torch_name, "working")
|
||||
|
||||
if os.path.exists(torch_path):
|
||||
logger.info("Torch pipeline already exists, reusing.")
|
||||
else:
|
||||
logger.info("Converting original Diffusers check to Torch model: %s -> %s", tensor_file, torch_path)
|
||||
extract_checkpoint(ctx, torch_name, tensor_file, from_hub=False)
|
||||
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
||||
extract_checkpoint(ctx, torch_name, source, from_hub=False)
|
||||
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
||||
|
||||
convert_diffusion_stable(ctx, model_path, working_name, opset, half, None)
|
||||
logger.info("ONNX pipeline saved to %s", model_name)
|
||||
convert_diffusion_stable(ctx, model, working_name)
|
||||
logger.info("ONNX pipeline saved to %s", name)
|
||||
|
|
|
@ -1,19 +1,22 @@
|
|||
from logging import getLogger
|
||||
from os import mkdir, path
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from diffusers import (
|
||||
OnnxRuntimeModel,
|
||||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from torch.onnx import export
|
||||
from logging import getLogger
|
||||
|
||||
from shutil import rmtree
|
||||
import torch
|
||||
from os import path, mkdir
|
||||
from pathlib import Path
|
||||
from onnx import load, save_model
|
||||
from torch.onnx import export
|
||||
|
||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from .utils import ConversionContext
|
||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -47,22 +50,22 @@ def onnx_export(
|
|||
|
||||
@torch.no_grad()
|
||||
def convert_diffusion_stable(
|
||||
ctx: ConversionContext,
|
||||
name: str,
|
||||
url: str,
|
||||
opset: int,
|
||||
half: bool,
|
||||
token: str,
|
||||
single_vae: bool = False,
|
||||
ctx: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
||||
):
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
dtype = torch.float16 if half else torch.float32
|
||||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
single_vae = model.get("single_vae")
|
||||
|
||||
dtype = torch.float16 if ctx.half else torch.float32
|
||||
dest_path = path.join(ctx.model_path, name)
|
||||
|
||||
# diffusers go into a directory rather than .onnx file
|
||||
logger.info("converting Stable Diffusion model %s: %s -> %s/", name, url, dest_path)
|
||||
logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path)
|
||||
|
||||
if single_vae:
|
||||
logger.info("converting model with single VAE")
|
||||
|
@ -71,13 +74,16 @@ def convert_diffusion_stable(
|
|||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if half and ctx.training_device != "cuda":
|
||||
if ctx.half and ctx.training_device != "cuda":
|
||||
raise ValueError(
|
||||
"Half precision model export is only supported on GPUs with CUDA"
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
url, torch_dtype=dtype, use_auth_token=token
|
||||
source,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=ctx.token,
|
||||
# cache_dir=path.join(ctx.cache_path, name)
|
||||
).to(ctx.training_device)
|
||||
output_path = Path(dest_path)
|
||||
|
||||
|
@ -94,14 +100,16 @@ def convert_diffusion_stable(
|
|||
onnx_export(
|
||||
pipeline.text_encoder,
|
||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||
model_args=(text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)),
|
||||
model_args=(
|
||||
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)
|
||||
),
|
||||
output_path=output_path / "text_encoder" / "model.onnx",
|
||||
ordered_input_names=["input_ids"],
|
||||
output_names=["last_hidden_state", "pooler_output"],
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch", 1: "sequence"},
|
||||
},
|
||||
opset=opset,
|
||||
opset=ctx.opset,
|
||||
)
|
||||
del pipeline.text_encoder
|
||||
|
||||
|
@ -113,7 +121,9 @@ def convert_diffusion_stable(
|
|||
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
|
||||
else:
|
||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
||||
unet_scale = torch.tensor(False).to(device=ctx.training_device, dtype=torch.bool)
|
||||
unet_scale = torch.tensor(False).to(
|
||||
device=ctx.training_device, dtype=torch.bool
|
||||
)
|
||||
|
||||
unet_in_channels = pipeline.unet.config.in_channels
|
||||
unet_sample_size = pipeline.unet.config.sample_size
|
||||
|
@ -139,7 +149,7 @@ def convert_diffusion_stable(
|
|||
"timestep": {0: "batch"},
|
||||
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
||||
},
|
||||
opset=opset,
|
||||
opset=ctx.opset,
|
||||
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
|
||||
)
|
||||
unet_model_path = str(unet_path.absolute().as_posix())
|
||||
|
@ -182,7 +192,7 @@ def convert_diffusion_stable(
|
|||
dynamic_axes={
|
||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
opset=ctx.opset,
|
||||
)
|
||||
else:
|
||||
# VAE ENCODER
|
||||
|
@ -207,7 +217,7 @@ def convert_diffusion_stable(
|
|||
dynamic_axes={
|
||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
opset=ctx.opset,
|
||||
)
|
||||
|
||||
# VAE DECODER
|
||||
|
@ -230,7 +240,7 @@ def convert_diffusion_stable(
|
|||
dynamic_axes={
|
||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
opset=ctx.opset,
|
||||
)
|
||||
|
||||
del pipeline.vae
|
||||
|
@ -261,7 +271,7 @@ def convert_diffusion_stable(
|
|||
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
|
||||
},
|
||||
opset=opset,
|
||||
opset=ctx.opset,
|
||||
)
|
||||
del pipeline.safety_checker
|
||||
safety_checker = OnnxRuntimeModel.from_pretrained(
|
||||
|
@ -312,4 +322,3 @@ def convert_diffusion_stable(
|
|||
)
|
||||
|
||||
logger.info("ONNX pipeline is loadable")
|
||||
|
||||
|
|
|
@ -1,32 +1,34 @@
|
|||
import torch
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from shutil import copyfile
|
||||
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.onnx import export
|
||||
from os import path
|
||||
from logging import getLogger
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from .utils import ConversionContext
|
||||
|
||||
from .utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(ctx.model_path, name + ".pth")
|
||||
dest_onnx = path.join(ctx.model_path, name + ".onnx")
|
||||
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
|
||||
def convert_upscale_resrgan(
|
||||
ctx: ConversionContext,
|
||||
model: ModelDict,
|
||||
source: str,
|
||||
):
|
||||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
scale = model.get("scale")
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
dest = path.join(ctx.model_path, name + ".onnx")
|
||||
logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
|
||||
|
||||
if path.isfile(dest):
|
||||
logger.info("ONNX model already exists, skipping.")
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
logger.info("PTH model not found, downloading...")
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
||||
)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
logger.info("loading and training model")
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
|
@ -37,7 +39,7 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale:
|
|||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=ctx.map_location)
|
||||
torch_model = torch.load(source, map_location=ctx.map_location)
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"])
|
||||
else:
|
||||
|
@ -54,15 +56,15 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale:
|
|||
"output": {2: "width", 3: "height"},
|
||||
}
|
||||
|
||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
||||
logger.info("exporting ONNX model to %s", dest)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
dest_onnx,
|
||||
dest,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=opset,
|
||||
opset_version=ctx.opset,
|
||||
export_params=True,
|
||||
)
|
||||
logger.info("Real ESRGAN exported to ONNX successfully.")
|
||||
|
|
|
@ -1,7 +1,129 @@
|
|||
import shutil
|
||||
from functools import partial
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
ModelDict = Dict[str, Union[str, int]]
|
||||
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
||||
|
||||
|
||||
class ConversionContext:
|
||||
def __init__(self, model_path: str, device: str) -> None:
|
||||
self.model_path = model_path
|
||||
self.training_device = device
|
||||
self.map_location = torch.device(device)
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
device: str,
|
||||
cache_path: Optional[str] = None,
|
||||
half: Optional[bool] = False,
|
||||
opset: Optional[int] = None,
|
||||
token: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model_path = model_path
|
||||
self.cache_path = cache_path or path.join(model_path, ".cache")
|
||||
self.training_device = device
|
||||
self.map_location = torch.device(device)
|
||||
self.half = half
|
||||
self.opset = opset
|
||||
self.token = token
|
||||
|
||||
|
||||
def download_progress(urls: List[Tuple[str, str]]):
    """
    Stream a remote file to disk while showing a progress bar.

    ``urls`` is a list of ``(url, dest)`` pairs. Note that the function
    returns from inside the loop, so only the first pair is ever processed:
    either its destination already exists, or it is downloaded.

    Returns the absolute destination path as a string, or ``None`` when
    ``urls`` is empty.

    Raises ``requests.HTTPError`` (via ``raise_for_status``) or
    ``RuntimeError`` when the server does not answer with status 200. Any
    error raised while the body is being written is re-raised after the
    partial file has been removed.
    """
    for url, dest in urls:
        dest_path = Path(dest).expanduser().resolve()
        dest_path.parent.mkdir(parents=True, exist_ok=True)

        # an existing file is treated as a finished download and reused
        if dest_path.exists():
            logger.info("Destination already exists: %s", dest_path)
            return str(dest_path.absolute())

        req = requests.get(url, stream=True, allow_redirects=True)
        if req.status_code != 200:
            # raise_for_status raises for 4xx and 5xx responses; any other
            # non-200 status falls through to the RuntimeError below
            req.raise_for_status()
            raise RuntimeError(
                "Request to %s failed with status code: %s" % (url, req.status_code)
            )

        # a missing Content-Length gives a total of 0 and an "unknown" label
        total = int(req.headers.get("Content-Length", 0))
        desc = "unknown" if total == 0 else ""
        # have urllib3 undo any Content-Encoding while the raw stream is read
        req.raw.read = partial(req.raw.read, decode_content=True)

        try:
            with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
                with dest_path.open("wb") as f:
                    shutil.copyfileobj(data, f)
        except BaseException:
            # A truncated file would pass the exists() check above on the next
            # run and be reused as if complete, so remove it before re-raising
            # (this deliberately includes KeyboardInterrupt).
            if dest_path.exists():
                dest_path.unlink()
            raise

        return str(dest_path.absolute())
|
||||
|
||||
|
||||
def tuple_to_correction(model: "Union[ModelDict, LegacyModel]"):
    """
    Normalize a correction model entry to its dict form.

    Legacy entries are positional lists or tuples laid out as
    ``(name, source, scale, half, opset)``, where everything after ``source``
    is optional. Entries that are not a list or tuple are returned unchanged.

    Raises ``ValueError`` when a legacy entry has fewer than two items.
    """
    if not isinstance(model, (list, tuple)):
        return model

    name, source, *rest = model
    # each optional field has its own position; missing ones get a default
    scale = rest[0] if len(rest) > 0 else 1
    half = rest[1] if len(rest) > 1 else False
    opset = rest[2] if len(rest) > 2 else None

    return {
        "name": name,
        "source": source,
        "half": half,
        "opset": opset,
        "scale": scale,
    }
|
||||
|
||||
|
||||
def tuple_to_diffusion(model: "Union[ModelDict, LegacyModel]"):
    """
    Normalize a diffusion model entry to its dict form.

    Legacy entries are positional lists or tuples laid out as
    ``(name, source, single_vae, half, opset)``, where everything after
    ``source`` is optional. Entries that are not a list or tuple are returned
    unchanged.

    Raises ``ValueError`` when a legacy entry has fewer than two items.
    """
    if not isinstance(model, (list, tuple)):
        return model

    name, source, *rest = model
    # each optional field has its own position; missing ones get a default
    single_vae = rest[0] if len(rest) > 0 else False
    half = rest[1] if len(rest) > 1 else False
    opset = rest[2] if len(rest) > 2 else None

    return {
        "name": name,
        "source": source,
        "half": half,
        "opset": opset,
        "single_vae": single_vae,
    }
|
||||
|
||||
|
||||
def tuple_to_upscaling(model: "Union[ModelDict, LegacyModel]"):
    """
    Normalize an upscaling model entry to its dict form.

    Legacy entries are positional lists or tuples laid out as
    ``(name, source, scale, half, opset)``, where everything after ``source``
    is optional. Entries that are not a list or tuple are returned unchanged.

    Raises ``ValueError`` when a legacy entry has fewer than two items.
    """
    if not isinstance(model, (list, tuple)):
        return model

    name, source, *rest = model
    # each optional field has its own position; missing ones get a default
    scale = rest[0] if len(rest) > 0 else 1
    half = rest[1] if len(rest) > 1 else False
    opset = rest[2] if len(rest) > 2 else None

    return {
        "name": name,
        "source": source,
        "half": half,
        "opset": opset,
        "scale": scale,
    }
|
||||
|
||||
|
||||
known_formats = ["onnx", "pth", "ckpt", "safetensors"]


def source_format(model: Dict) -> Optional[str]:
    """
    Work out the file format of a model entry.

    An explicit ``format`` key always wins and is returned as-is. Otherwise
    the extension of ``source`` is used when it is one of ``known_formats``
    (compared case-insensitively and returned in lower case).

    Returns ``None`` when neither gives a known format.
    """
    if "format" in model:
        return model["format"]

    if "source" in model:
        # splitext returns (root, ext) and ext keeps its leading dot, so take
        # the second item and strip the dot before comparing
        _root, ext = path.splitext(model["source"])
        ext = ext[1:].lower()
        if ext in known_formats:
            return ext

    return None
|
||||
|
|
|
@ -649,10 +649,10 @@ def chain():
|
|||
return error_reply("chain pipeline must have a body")
|
||||
|
||||
data = yaml.safe_load(body)
|
||||
with open("./schema.yaml", "r") as f:
|
||||
with open("./schemas/chain.yaml", "r") as f:
|
||||
schema = yaml.safe_load(f.read())
|
||||
|
||||
logger.info("validating chain request: %s against %s", data, schema)
|
||||
logger.debug("validating chain request: %s against %s", data, schema)
|
||||
validate(data, schema)
|
||||
|
||||
# get defaults from the regular parameters
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[tool.black]
|
||||
force-exclude = '''/(lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py'''
|
||||
force-exclude = '''/(diffusion_original|lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py'''
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml
|
||||
$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/chain.yaml
|
||||
$schema: https://json-schema.org/draft/2020-12/schema
|
||||
|
||||
$defs:
|
|
@ -0,0 +1,65 @@
|
|||
# JSON Schema for the extras file: extra diffusion, correction and upscaling
# models, each given as a legacy positional tuple or as a model object.
$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/extras.yaml
$schema: https://json-schema.org/draft/2020-12/schema

$defs:
  # legacy positional form: [name, source, ...optional flags and numbers]
  legacy_tuple:
    type: array
    items:
      oneOf:
        - type: string
        - type: number
        # optional flags are booleans, which JSON Schema does not treat as
        # numbers, so they need their own branch
        - type: boolean

  # fields shared by every kind of model
  base_model:
    type: object
    required: [name, source]
    properties:
      format:
        type: string
        enum: [onnx, pth, ckpt, safetensors]
      half:
        type: boolean
      name:
        type: string
      opset:
        type: number
      source:
        type: string

  correction_model:
    allOf:
      - $ref: "#/$defs/base_model"

  diffusion_model:
    allOf:
      - $ref: "#/$defs/base_model"

  # upscaling models must also declare their scale factor
  upscaling_model:
    allOf:
      - $ref: "#/$defs/base_model"
      - type: object
        required: [scale]
        properties:
          scale:
            type: number

type: object
additionalProperties: False
properties:
  diffusion:
    type: array
    items:
      oneOf:
        - $ref: "#/$defs/legacy_tuple"
        - $ref: "#/$defs/diffusion_model"
  correction:
    type: array
    items:
      oneOf:
        - $ref: "#/$defs/legacy_tuple"
        - $ref: "#/$defs/correction_model"
  upscaling:
    type: array
    items:
      oneOf:
        - $ref: "#/$defs/legacy_tuple"
        - $ref: "#/$defs/upscaling_model"
|
|
@ -58,6 +58,7 @@
|
|||
"rocm",
|
||||
"RRDB",
|
||||
"runwayml",
|
||||
"safetensors",
|
||||
"scandir",
|
||||
"scipy",
|
||||
"scrollback",
|
||||
|
|
Loading…
Reference in New Issue