feat(api): add a way to download models from civitai or other https sources (#117)
This commit is contained in:
parent
b3e4076775
commit
9f202486c2
|
@ -1,9 +1,23 @@
|
||||||
{
|
{
|
||||||
"diffusion": [
|
"diffusion": [
|
||||||
["diffusion-knollingcase", "Aybeeceedee/knollingcase"],
|
{
|
||||||
["diffusion-openjourney", "prompthero/openjourney"],
|
"name": "diffusion-knollingcase",
|
||||||
["diffusion-stably-diffused-onnx-v2-6", "../models/tensors/stablydiffuseds_26.safetensors"],
|
"source": "Aybeeceedee/knollingcase"
|
||||||
["diffusion-unstable-ink-dream-onnx-v6", "../models/tensors/unstableinkdream_v6.safetensors"]
|
},
|
||||||
|
{
|
||||||
|
"name": "diffusion-openjourney",
|
||||||
|
"source": "prompthero/openjourney"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "diffusion-stablydiffused-aesthetic-v2-6",
|
||||||
|
"source": "civitai://6266?type=Pruned%20Model&format=SafeTensor",
|
||||||
|
"format": "safetensors"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "diffusion-unstable-ink-dream-v6",
|
||||||
|
"source": "civitai://5796",
|
||||||
|
"format": "safetensors"
|
||||||
|
}
|
||||||
],
|
],
|
||||||
"correction": [],
|
"correction": [],
|
||||||
"upscaling": []
|
"upscaling": []
|
||||||
|
|
|
@ -1,9 +1,3 @@
|
||||||
from .correction_gfpgan import convert_correction_gfpgan
|
|
||||||
from .diffusion_original import convert_diffusion_original
|
|
||||||
from .diffusion_stable import convert_diffusion_stable
|
|
||||||
from .upscale_resrgan import convert_upscale_resrgan
|
|
||||||
from .utils import ConversionContext
|
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from json import loads
|
from json import loads
|
||||||
|
@ -11,9 +5,17 @@ from logging import getLogger
|
||||||
from os import environ, makedirs, path
|
from os import environ, makedirs, path
|
||||||
from sys import exit
|
from sys import exit
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from yaml import safe_load
|
||||||
|
from jsonschema import validate, ValidationError
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from .correction_gfpgan import convert_correction_gfpgan
|
||||||
|
from .diffusion_original import convert_diffusion_original
|
||||||
|
from .diffusion_stable import convert_diffusion_stable
|
||||||
|
from .upscale_resrgan import convert_upscale_resrgan
|
||||||
|
from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
|
||||||
|
|
||||||
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
"ignore", ".*The shape inference of prim::Constant type is missing.*"
|
||||||
|
@ -29,20 +31,39 @@ Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
model_sources: Dict[str, Tuple[str, str]] = {
|
||||||
|
"civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
|
||||||
|
}
|
||||||
|
|
||||||
|
model_source_huggingface = "huggingface://"
|
||||||
|
|
||||||
# recommended models
|
# recommended models
|
||||||
base_models: Models = {
|
base_models: Models = {
|
||||||
"diffusion": [
|
"diffusion": [
|
||||||
# v1.x
|
# v1.x
|
||||||
("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"),
|
|
||||||
("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"),
|
|
||||||
# v2.x
|
|
||||||
("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"),
|
|
||||||
(
|
(
|
||||||
"stable-diffusion-onnx-v2-inpainting",
|
"stable-diffusion-onnx-v1-5",
|
||||||
"stabilityai/stable-diffusion-2-inpainting",
|
model_source_huggingface + "runwayml/stable-diffusion-v1-5",
|
||||||
),
|
),
|
||||||
|
# (
|
||||||
|
# "stable-diffusion-onnx-v1-inpainting",
|
||||||
|
# model_source_huggingface + "runwayml/stable-diffusion-inpainting",
|
||||||
|
# ),
|
||||||
|
# v2.x
|
||||||
|
# (
|
||||||
|
# "stable-diffusion-onnx-v2-1",
|
||||||
|
# model_source_huggingface + "stabilityai/stable-diffusion-2-1",
|
||||||
|
# ),
|
||||||
|
# (
|
||||||
|
# "stable-diffusion-onnx-v2-inpainting",
|
||||||
|
# model_source_huggingface + "stabilityai/stable-diffusion-2-inpainting",
|
||||||
|
# ),
|
||||||
# TODO: should have its own converter
|
# TODO: should have its own converter
|
||||||
("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"),
|
(
|
||||||
|
"upscaling-stable-diffusion-x4",
|
||||||
|
model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
|
||||||
|
True,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
"correction": [
|
"correction": [
|
||||||
(
|
(
|
||||||
|
@ -79,35 +100,86 @@ model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
||||||
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
def load_models(args, ctx: ConversionContext, models: Models):
|
def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str:
|
||||||
|
cache_name = path.join(ctx.cache_path, name)
|
||||||
|
if format is not None:
|
||||||
|
# add an extension if possible, some of the conversion code checks for it
|
||||||
|
cache_name = "%s.%s" % (cache_name, format)
|
||||||
|
|
||||||
|
for proto in model_sources:
|
||||||
|
api_name, api_root = model_sources.get(proto)
|
||||||
|
if source.startswith(proto):
|
||||||
|
api_source = api_root % (source.removeprefix(proto))
|
||||||
|
logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
|
||||||
|
return download_progress([(api_source, cache_name)])
|
||||||
|
|
||||||
|
if source.startswith(model_source_huggingface):
|
||||||
|
hub_source = source.removeprefix(model_source_huggingface)
|
||||||
|
logger.info("Downloading model from Huggingface Hub: %s", hub_source)
|
||||||
|
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
|
||||||
|
return hub_source
|
||||||
|
elif source.startswith("https://"):
|
||||||
|
logger.info("Downloading model from: %s", source)
|
||||||
|
return download_progress([(source, cache_name)])
|
||||||
|
elif source.startswith("http://"):
|
||||||
|
logger.warning("Downloading model from insecure source: %s", source)
|
||||||
|
return download_progress([(source, cache_name)])
|
||||||
|
elif source.startswith(path.sep) or source.startswith("."):
|
||||||
|
logger.info("Using local model: %s", source)
|
||||||
|
return source
|
||||||
|
else:
|
||||||
|
logger.info("Unknown model location, using path as provided: %s", source)
|
||||||
|
return source
|
||||||
|
|
||||||
|
|
||||||
|
def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
if args.diffusion:
|
if args.diffusion:
|
||||||
for source in models.get("diffusion"):
|
for model in models.get("diffusion"):
|
||||||
name, file = source
|
model = tuple_to_diffusion(model)
|
||||||
|
name = model.get("name")
|
||||||
|
|
||||||
if name in args.skip:
|
if name in args.skip:
|
||||||
logger.info("Skipping model: %s", source[0])
|
logger.info("Skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
if file.endswith(".safetensors") or file.endswith(".ckpt"):
|
format = source_format(model)
|
||||||
convert_diffusion_original(ctx, *source, args.opset, args.half)
|
source = fetch_model(ctx, name, model["source"], format=format)
|
||||||
|
|
||||||
|
if format in ["safetensors", "ckpt"]:
|
||||||
|
convert_diffusion_original(
|
||||||
|
ctx,
|
||||||
|
model,
|
||||||
|
source,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# TODO: make this a parameter in the JSON/dict
|
|
||||||
single_vae = "upscaling" in source[0]
|
|
||||||
convert_diffusion_stable(
|
convert_diffusion_stable(
|
||||||
ctx, *source, args.opset, args.half, args.token, single_vae=single_vae
|
ctx,
|
||||||
|
model,
|
||||||
|
source,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.upscaling:
|
if args.upscaling:
|
||||||
for source in models.get("upscaling"):
|
for model in models.get("upscaling"):
|
||||||
if source[0] in args.skip:
|
model = tuple_to_upscaling(model)
|
||||||
logger.info("Skipping model: %s", source[0])
|
name = model.get("name")
|
||||||
|
|
||||||
|
if name in args.skip:
|
||||||
|
logger.info("Skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
convert_upscale_resrgan(ctx, *source, args.opset)
|
format = source_format(model)
|
||||||
|
source = fetch_model(ctx, name, model["source"], format=format)
|
||||||
|
convert_upscale_resrgan(ctx, model, source)
|
||||||
|
|
||||||
if args.correction:
|
if args.correction:
|
||||||
for source in models.get("correction"):
|
for model in models.get("correction"):
|
||||||
if source[0] in args.skip:
|
model = tuple_to_correction(model)
|
||||||
logger.info("Skipping model: %s", source[0])
|
name = model.get("name")
|
||||||
|
|
||||||
|
if name in args.skip:
|
||||||
|
logger.info("Skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
convert_correction_gfpgan(ctx, *source, args.opset)
|
format = source_format(model)
|
||||||
|
source = fetch_model(ctx, name, model["source"], format=format)
|
||||||
|
convert_correction_gfpgan(ctx, model, source)
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
|
@ -146,7 +218,7 @@ def main() -> int:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info("CLI arguments: %s", args)
|
logger.info("CLI arguments: %s", args)
|
||||||
|
|
||||||
ctx = ConversionContext(model_path, training_device)
|
ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
|
||||||
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
||||||
|
|
||||||
if not path.exists(model_path):
|
if not path.exists(model_path):
|
||||||
|
@ -154,16 +226,26 @@ def main() -> int:
|
||||||
makedirs(model_path)
|
makedirs(model_path)
|
||||||
|
|
||||||
logger.info("Converting base models.")
|
logger.info("Converting base models.")
|
||||||
load_models(args, ctx, base_models)
|
convert_models(ctx, args, base_models)
|
||||||
|
|
||||||
for file in args.extras:
|
for file in args.extras:
|
||||||
if file is not None and file != "":
|
if file is not None and file != "":
|
||||||
logger.info("Loading extra models from %s", file)
|
logger.info("Loading extra models from %s", file)
|
||||||
try:
|
try:
|
||||||
with open(file, "r") as f:
|
with open(file, "r") as f:
|
||||||
data = loads(f.read())
|
data = safe_load(f.read())
|
||||||
|
|
||||||
|
with open("./schemas/extras.yaml", "r") as f:
|
||||||
|
schema = safe_load(f.read())
|
||||||
|
|
||||||
|
logger.debug("validating chain request: %s against %s", data, schema)
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate(data, schema)
|
||||||
logger.info("Converting extra models.")
|
logger.info("Converting extra models.")
|
||||||
load_models(args, ctx, data)
|
convert_models(ctx, args, data)
|
||||||
|
except ValidationError as err:
|
||||||
|
logger.error("Invalid data in extras file: %s", err)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.error("Error converting extra models: %s", err)
|
logger.error("Error converting extra models: %s", err)
|
||||||
|
|
||||||
|
|
|
@ -1,31 +1,34 @@
|
||||||
import torch
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
from os import path
|
|
||||||
from logging import getLogger
|
from .utils import ConversionContext, ModelDict
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from .utils import ConversionContext
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
|
|
||||||
dest_path = path.join(ctx.model_path, name + ".pth")
|
|
||||||
dest_onnx = path.join(ctx.model_path, name + ".onnx")
|
|
||||||
logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx)
|
|
||||||
|
|
||||||
if path.isfile(dest_onnx):
|
@torch.no_grad()
|
||||||
|
def convert_correction_gfpgan(
|
||||||
|
ctx: ConversionContext,
|
||||||
|
model: ModelDict,
|
||||||
|
source: str,
|
||||||
|
):
|
||||||
|
name = model.get("name")
|
||||||
|
source = source or model.get("source")
|
||||||
|
scale = model.get("scale")
|
||||||
|
|
||||||
|
dest = path.join(ctx.model_path, name + ".onnx")
|
||||||
|
logger.info("converting GFPGAN model: %s -> %s", name, dest)
|
||||||
|
|
||||||
|
if path.isfile(dest):
|
||||||
logger.info("ONNX model already exists, skipping.")
|
logger.info("ONNX model already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not path.isfile(dest_path):
|
|
||||||
logger.info("PTH model not found, downloading...")
|
|
||||||
download_path = load_file_from_url(
|
|
||||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
|
||||||
)
|
|
||||||
copyfile(download_path, dest_path)
|
|
||||||
|
|
||||||
logger.info("loading and training model")
|
logger.info("loading and training model")
|
||||||
model = RRDBNet(
|
model = RRDBNet(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
|
@ -36,7 +39,7 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale
|
||||||
scale=scale,
|
scale=scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_model = torch.load(dest_path, map_location=ctx.map_location)
|
torch_model = torch.load(source, map_location=ctx.map_location)
|
||||||
# TODO: make sure strict=False is safe here
|
# TODO: make sure strict=False is safe here
|
||||||
if "params_ema" in torch_model:
|
if "params_ema" in torch_model:
|
||||||
model.load_state_dict(torch_model["params_ema"], strict=False)
|
model.load_state_dict(torch_model["params_ema"], strict=False)
|
||||||
|
@ -54,15 +57,15 @@ def convert_correction_gfpgan(ctx: ConversionContext, name: str, url: str, scale
|
||||||
"output": {2: "width", 3: "height"},
|
"output": {2: "width", 3: "height"},
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
logger.info("exporting ONNX model to %s", dest)
|
||||||
export(
|
export(
|
||||||
model,
|
model,
|
||||||
rng,
|
rng,
|
||||||
dest_onnx,
|
dest,
|
||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
dynamic_axes=dynamic_axes,
|
dynamic_axes=dynamic_axes,
|
||||||
opset_version=opset,
|
opset_version=ctx.opset,
|
||||||
export_params=True,
|
export_params=True,
|
||||||
)
|
)
|
||||||
logger.info("GFPGAN exported to ONNX successfully.")
|
logger.info("GFPGAN exported to ONNX successfully.")
|
||||||
|
|
|
@ -11,6 +11,17 @@
|
||||||
# TODO: ask about license before merging
|
# TODO: ask about license before merging
|
||||||
###
|
###
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import traceback
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import huggingface_hub.utils.tqdm
|
||||||
|
import safetensors.torch
|
||||||
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
|
@ -20,95 +31,34 @@ from diffusers import (
|
||||||
HeunDiscreteScheduler,
|
HeunDiscreteScheduler,
|
||||||
LDMTextToImagePipeline,
|
LDMTextToImagePipeline,
|
||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
|
PaintByExamplePipeline,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
UNet2DConditionModel, PaintByExamplePipeline,
|
UNet2DConditionModel,
|
||||||
|
)
|
||||||
|
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
||||||
|
LDMBertConfig,
|
||||||
|
LDMBertModel,
|
||||||
)
|
)
|
||||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
|
||||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
from logging import getLogger
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
|
from transformers import (
|
||||||
from typing import Dict, List, Union
|
AutoFeatureExtractor,
|
||||||
|
BertTokenizerFast,
|
||||||
import huggingface_hub.utils.tqdm
|
CLIPTextModel,
|
||||||
import json
|
CLIPTokenizer,
|
||||||
import os
|
CLIPVisionConfig,
|
||||||
import re
|
)
|
||||||
import safetensors.torch
|
|
||||||
import shutil
|
|
||||||
import torch
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from .diffusion_stable import convert_diffusion_stable
|
from .diffusion_stable import convert_diffusion_stable
|
||||||
from .utils import ConversionContext
|
from .utils import ConversionContext, ModelDict
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_images():
|
|
||||||
return []
|
|
||||||
|
|
||||||
class Concept(BaseModel):
|
|
||||||
class_data_dir: str = ""
|
|
||||||
class_guidance_scale: float = 7.5
|
|
||||||
class_infer_steps: int = 60
|
|
||||||
class_negative_prompt: str = ""
|
|
||||||
class_prompt: str = ""
|
|
||||||
class_token: str = ""
|
|
||||||
instance_data_dir: str = ""
|
|
||||||
instance_prompt: str = ""
|
|
||||||
instance_token: str = ""
|
|
||||||
is_valid: bool = False
|
|
||||||
n_save_sample: int = 1
|
|
||||||
num_class_images: int = 0
|
|
||||||
num_class_images_per: int = 0
|
|
||||||
sample_seed: int = -1
|
|
||||||
save_guidance_scale: float = 7.5
|
|
||||||
save_infer_steps: int = 60
|
|
||||||
save_sample_negative_prompt: str = ""
|
|
||||||
save_sample_prompt: str = ""
|
|
||||||
save_sample_template: str = ""
|
|
||||||
|
|
||||||
def __init__(self, input_dict: Union[Dict, None] = None, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
if input_dict is not None:
|
|
||||||
self.load_params(input_dict)
|
|
||||||
if self.is_valid and self.num_class_images != 0:
|
|
||||||
if self.num_class_images_per == 0:
|
|
||||||
images = get_images(self.instance_data_dir)
|
|
||||||
if len(images) < self.num_class_images * 2:
|
|
||||||
self.num_class_images_per = 1
|
|
||||||
else:
|
|
||||||
self.num_class_images_per = self.num_class_images // len(images)
|
|
||||||
self.num_class_images = 0
|
|
||||||
|
|
||||||
def to_dict(self):
|
|
||||||
return self.dict()
|
|
||||||
|
|
||||||
def to_json(self):
|
|
||||||
return json.dumps(self.to_dict())
|
|
||||||
|
|
||||||
def load_params(self, params_dict):
|
|
||||||
for key, value in params_dict.items():
|
|
||||||
if hasattr(self, key):
|
|
||||||
setattr(self, key, value)
|
|
||||||
if self.instance_data_dir:
|
|
||||||
self.is_valid = os.path.isdir(self.instance_data_dir)
|
|
||||||
else:
|
|
||||||
self.is_valid = False
|
|
||||||
|
|
||||||
|
|
||||||
# Keys to save, replacing our dumb __init__ method
|
|
||||||
save_keys = []
|
|
||||||
|
|
||||||
# Keys to return to the ui when Load Settings is clicked.
|
|
||||||
ui_keys = []
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_name(name):
|
def sanitize_name(name):
|
||||||
return "".join(x for x in name if (x.isalnum() or x in "._- "))
|
return "".join(x for x in name if (x.isalnum() or x in "._- "))
|
||||||
|
|
||||||
|
@ -200,7 +150,7 @@ class DreamboothConfig(BaseModel):
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
model_name = sanitize_name(model_name)
|
model_name = sanitize_name(model_name)
|
||||||
model_dir = os.path.join(ctx.model_path, model_name)
|
model_dir = os.path.join(ctx.cache_path, model_name)
|
||||||
working_dir = os.path.join(model_dir, "working")
|
working_dir = os.path.join(model_dir, "working")
|
||||||
|
|
||||||
if not os.path.exists(working_dir):
|
if not os.path.exists(working_dir):
|
||||||
|
@ -214,7 +164,6 @@ class DreamboothConfig(BaseModel):
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.v2 = v2
|
self.v2 = v2
|
||||||
|
|
||||||
# Actually save as a file
|
|
||||||
def save(self, backup=False):
|
def save(self, backup=False):
|
||||||
"""
|
"""
|
||||||
Save the config file
|
Save the config file
|
||||||
|
@ -236,132 +185,6 @@ class DreamboothConfig(BaseModel):
|
||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
# Pass a dict and return a list of Concept objects
|
|
||||||
def concepts(self, required: int = -1):
|
|
||||||
concepts = []
|
|
||||||
c_idx = 0
|
|
||||||
# If using a file for concepts and not requesting from UI, load from file
|
|
||||||
if self.use_concepts and self.concepts_path and required == -1:
|
|
||||||
concepts_list = concepts_from_file(self.concepts_path)
|
|
||||||
|
|
||||||
# Otherwise, use 'stored' list
|
|
||||||
else:
|
|
||||||
concepts_list = self.concepts_list
|
|
||||||
if required == -1:
|
|
||||||
required = len(concepts_list)
|
|
||||||
|
|
||||||
for concept_dict in concepts_list:
|
|
||||||
concept = Concept(input_dict=concept_dict)
|
|
||||||
if concept.is_valid:
|
|
||||||
if concept.class_data_dir == "" or concept.class_data_dir is None:
|
|
||||||
concept.class_data_dir = os.path.join(self.model_dir, f"classifiers_{c_idx}")
|
|
||||||
concepts.append(concept)
|
|
||||||
c_idx += 1
|
|
||||||
|
|
||||||
missing = len(concepts) - required
|
|
||||||
if missing > 0:
|
|
||||||
concepts.extend([Concept(None)] * missing)
|
|
||||||
return concepts
|
|
||||||
|
|
||||||
# Set default values
|
|
||||||
def check_defaults(self):
|
|
||||||
if self.model_name is not None and self.model_name != "":
|
|
||||||
if self.revision == "" or self.revision is None:
|
|
||||||
self.revision = 0
|
|
||||||
if self.epoch == "" or self.epoch is None:
|
|
||||||
self.epoch = 0
|
|
||||||
self.model_name = "".join(x for x in self.model_name if (x.isalnum() or x in "._- "))
|
|
||||||
models_path = "." # TODO: use ctx path
|
|
||||||
model_dir = os.path.join(models_path, self.model_name)
|
|
||||||
working_dir = os.path.join(model_dir, "working")
|
|
||||||
if not os.path.exists(working_dir):
|
|
||||||
os.makedirs(working_dir)
|
|
||||||
self.model_dir = model_dir
|
|
||||||
self.pretrained_model_name_or_path = working_dir
|
|
||||||
|
|
||||||
|
|
||||||
def concepts_from_file(concepts_path: str):
|
|
||||||
concepts = []
|
|
||||||
if os.path.exists(concepts_path) and os.path.isfile(str):
|
|
||||||
try:
|
|
||||||
with open(concepts_path,"r") as concepts_file:
|
|
||||||
concepts_str = concepts_file.read()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Exception opening concepts file: {e}")
|
|
||||||
else:
|
|
||||||
concepts_str = concepts_path
|
|
||||||
|
|
||||||
try:
|
|
||||||
concepts_data = json.loads(concepts_str)
|
|
||||||
for concept_data in concepts_data:
|
|
||||||
concept = Concept(input_dict=concept_data)
|
|
||||||
if concept.is_valid:
|
|
||||||
concepts.append(concept.__dict__)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Exception parsing concepts: {e}")
|
|
||||||
return concepts
|
|
||||||
|
|
||||||
|
|
||||||
def save_config(*args):
|
|
||||||
raise Exception("where tho")
|
|
||||||
params = list(args)
|
|
||||||
concept_keys = ["c1_", "c2_", "c3_", "c4_"]
|
|
||||||
model_name = params[0]
|
|
||||||
if model_name is None or model_name == "":
|
|
||||||
print("Invalid model name.")
|
|
||||||
return
|
|
||||||
config = from_file(ctx, model_name)
|
|
||||||
if config is None:
|
|
||||||
config = DreamboothConfig(model_name)
|
|
||||||
params_dict = dict(zip(save_keys, params))
|
|
||||||
concepts_list = []
|
|
||||||
# If using a concepts file/string, keep concepts_list empty.
|
|
||||||
if params_dict["db_use_concepts"] and params_dict["db_concepts_path"]:
|
|
||||||
concepts_list = []
|
|
||||||
params_dict["concepts_list"] = concepts_list
|
|
||||||
else:
|
|
||||||
for concept_key in concept_keys:
|
|
||||||
concept_dict = {}
|
|
||||||
for key, param in params_dict.items():
|
|
||||||
if concept_key in key and param is not None:
|
|
||||||
concept_dict[key.replace(concept_key, "")] = param
|
|
||||||
concept_test = Concept(concept_dict)
|
|
||||||
if concept_test.is_valid:
|
|
||||||
concepts_list.append(concept_test.__dict__)
|
|
||||||
existing_concepts = params_dict["concepts_list"] if "concepts_list" in params_dict else []
|
|
||||||
if len(concepts_list) and not len(existing_concepts):
|
|
||||||
params_dict["concepts_list"] = concepts_list
|
|
||||||
|
|
||||||
config.load_params(params_dict)
|
|
||||||
config.save()
|
|
||||||
|
|
||||||
|
|
||||||
def from_file(ctx: ConversionContext, model_name):
|
|
||||||
"""
|
|
||||||
Load config data from UI
|
|
||||||
Args:
|
|
||||||
model_name: The config to load
|
|
||||||
|
|
||||||
Returns: Dict | None
|
|
||||||
|
|
||||||
"""
|
|
||||||
if model_name == "" or model_name is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
model_name = sanitize_name(model_name)
|
|
||||||
config_file = os.path.join(ctx.model_path, model_name, "db_config.json")
|
|
||||||
try:
|
|
||||||
with open(config_file, 'r') as openfile:
|
|
||||||
config_dict = json.load(openfile)
|
|
||||||
|
|
||||||
config = DreamboothConfig(model_name)
|
|
||||||
config.load_params(config_dict)
|
|
||||||
return config
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Exception loading config: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2022 The HuggingFace Inc. team.
|
# Copyright 2022 The HuggingFace Inc. team.
|
||||||
|
@ -379,8 +202,6 @@ def from_file(ctx: ConversionContext, model_name):
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Conversion script for the LDM checkpoints. """
|
""" Conversion script for the LDM checkpoints. """
|
||||||
|
|
||||||
def get_db_models():
|
|
||||||
return []
|
|
||||||
|
|
||||||
def shave_segments(path, n_shave_prefix_segments=1):
|
def shave_segments(path, n_shave_prefix_segments=1):
|
||||||
"""
|
"""
|
||||||
|
@ -1075,7 +896,7 @@ def convert_open_clip_checkpoint(checkpoint):
|
||||||
if 'cond_stage_model.model.text_projection' in checkpoint:
|
if 'cond_stage_model.model.text_projection' in checkpoint:
|
||||||
d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
|
d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
|
||||||
else:
|
else:
|
||||||
print("No projection shape found, setting to 1024")
|
logger.debug("No projection shape found, setting to 1024")
|
||||||
d_model = 1024
|
d_model = 1024
|
||||||
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
||||||
|
|
||||||
|
@ -1130,7 +951,7 @@ def replace_symlinks(path, base):
|
||||||
blob_path = None
|
blob_path = None
|
||||||
|
|
||||||
if blob_path is None:
|
if blob_path is None:
|
||||||
print("NO BLOB")
|
logger.debug("NO BLOB")
|
||||||
return
|
return
|
||||||
os.replace(blob_path, path)
|
os.replace(blob_path, path)
|
||||||
elif os.path.isdir(path):
|
elif os.path.isdir(path):
|
||||||
|
@ -1140,7 +961,6 @@ def replace_symlinks(path, base):
|
||||||
|
|
||||||
def download_model(db_config: DreamboothConfig, token):
|
def download_model(db_config: DreamboothConfig, token):
|
||||||
tmp_dir = os.path.join(db_config.model_dir, "src")
|
tmp_dir = os.path.join(db_config.model_dir, "src")
|
||||||
working_dir = db_config.pretrained_model_name_or_path
|
|
||||||
|
|
||||||
hub_url = db_config.src
|
hub_url = db_config.src
|
||||||
if "http" in hub_url or "huggingface.co" in hub_url:
|
if "http" in hub_url or "huggingface.co" in hub_url:
|
||||||
|
@ -1155,7 +975,7 @@ def download_model(db_config: DreamboothConfig, token):
|
||||||
)
|
)
|
||||||
|
|
||||||
if repo_info.sha is None:
|
if repo_info.sha is None:
|
||||||
print("Unable to fetch repo?")
|
logger.warning("Unable to fetch repo?")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
siblings = repo_info.siblings
|
siblings = repo_info.siblings
|
||||||
|
@ -1208,10 +1028,10 @@ def download_model(db_config: DreamboothConfig, token):
|
||||||
if files_to_fetch and config_file:
|
if files_to_fetch and config_file:
|
||||||
files_to_fetch.append(config_file)
|
files_to_fetch.append(config_file)
|
||||||
|
|
||||||
print(f"Fetching files: {files_to_fetch}")
|
logger.info(f"Fetching files: {files_to_fetch}")
|
||||||
|
|
||||||
if not len(files_to_fetch):
|
if not len(files_to_fetch):
|
||||||
print("Nothing to fetch!")
|
logger.debug("Nothing to fetch!")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1296,9 +1116,6 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
|
||||||
|
|
||||||
return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
|
return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
|
||||||
|
|
||||||
print("Could not find valid config. Returning default v1 config.")
|
|
||||||
return get_config_path(model_versions["v1"], train_types["default"], config_base_name, prediction_type="epsilon")
|
|
||||||
|
|
||||||
|
|
||||||
def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
|
def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
|
||||||
new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
|
new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
|
||||||
|
@ -1352,7 +1169,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
if db_config is not None:
|
if db_config is not None:
|
||||||
original_config_file = config
|
original_config_file = config
|
||||||
if model_info is not None:
|
if model_info is not None:
|
||||||
print("Got model info.")
|
logger.debug("Got model info.")
|
||||||
if ".ckpt" in model_info or ".safetensors" in model_info:
|
if ".ckpt" in model_info or ".safetensors" in model_info:
|
||||||
# Set this to false, because we have a checkpoint where we can *maybe* get a revision.
|
# Set this to false, because we have a checkpoint where we can *maybe* get a revision.
|
||||||
from_hub = False
|
from_hub = False
|
||||||
|
@ -1360,28 +1177,26 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
checkpoint_file = model_info
|
checkpoint_file = model_info
|
||||||
else:
|
else:
|
||||||
msg = "Unable to fetch model from hub."
|
msg = "Unable to fetch model from hub."
|
||||||
print(msg)
|
logger.warning(msg)
|
||||||
return "", "", 0, 0, "", "", "", "", image_size, "", msg
|
return "", "", 0, 0, "", "", "", "", image_size, "", msg
|
||||||
|
|
||||||
reset_safe = False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
map_location = torch.device("cpu")
|
map_location = torch.device("cpu")
|
||||||
|
|
||||||
# Try to determine if v1 or v2 model if we have a ckpt
|
# Try to determine if v1 or v2 model if we have a ckpt
|
||||||
if not from_hub:
|
if not from_hub:
|
||||||
print("Loading model from checkpoint.")
|
logger.info("Loading model from checkpoint.")
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
os.environ["SAFETENSORS_FAST_GPU"] = "1"
|
os.environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||||
try:
|
try:
|
||||||
print("Loading safetensors...")
|
logger.debug("Loading safetensors...")
|
||||||
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
|
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
checkpoint = torch.jit.load(checkpoint_file)
|
checkpoint = torch.jit.load(checkpoint_file)
|
||||||
else:
|
else:
|
||||||
print("Loading ckpt...")
|
logger.debug("Loading ckpt...")
|
||||||
checkpoint = torch.load(checkpoint_file, map_location=map_location)
|
checkpoint = torch.load(checkpoint_file, map_location=map_location)
|
||||||
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||||
|
|
||||||
|
@ -1401,7 +1216,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||||
if not is_512:
|
if not is_512:
|
||||||
# v2.1 needs to upcast attention
|
# v2.1 needs to upcast attention
|
||||||
print("Setting upcast_attention")
|
logger.debug("Setting upcast_attention")
|
||||||
upcast_attention = True
|
upcast_attention = True
|
||||||
v2 = True
|
v2 = True
|
||||||
else:
|
else:
|
||||||
|
@ -1410,15 +1225,15 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet")
|
unet_dir = os.path.join(db_config.pretrained_model_name_or_path, "unet")
|
||||||
try:
|
try:
|
||||||
unet = UNet2DConditionModel.from_pretrained(unet_dir)
|
unet = UNet2DConditionModel.from_pretrained(unet_dir)
|
||||||
print("Loaded unet.")
|
logger.debug("Loaded unet.")
|
||||||
unet_dict = unet.state_dict()
|
unet_dict = unet.state_dict()
|
||||||
key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"
|
key_name = "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
|
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
|
||||||
print("We got v2!")
|
logger.debug("UNet using v2 parameters.")
|
||||||
v2 = True
|
v2 = True
|
||||||
|
|
||||||
except:
|
except:
|
||||||
print("Exception loading unet!")
|
logger.error("Exception loading unet!")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
if v2 and not is_512:
|
if v2 and not is_512:
|
||||||
|
@ -1428,7 +1243,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
|
|
||||||
original_config_file = get_config_file(train_unfrozen, v2, prediction_type)
|
original_config_file = get_config_file(train_unfrozen, v2, prediction_type)
|
||||||
|
|
||||||
print(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
|
logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
|
||||||
db_config.resolution = image_size
|
db_config.resolution = image_size
|
||||||
db_config.lifetime_revision = revision
|
db_config.lifetime_revision = revision
|
||||||
db_config.epoch = epoch
|
db_config.epoch = epoch
|
||||||
|
@ -1438,7 +1253,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
db_config.save()
|
db_config.save()
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"{'v2' if v2 else 'v1'} model loaded.")
|
logger.info(f"{'v2' if v2 else 'v1'} model loaded.")
|
||||||
|
|
||||||
# Use existing YAML if present
|
# Use existing YAML if present
|
||||||
if checkpoint_file is not None:
|
if checkpoint_file is not None:
|
||||||
|
@ -1447,10 +1262,10 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
original_config_file = config_check
|
original_config_file = config_check
|
||||||
|
|
||||||
if original_config_file is None or not os.path.exists(original_config_file):
|
if original_config_file is None or not os.path.exists(original_config_file):
|
||||||
print("Unable to select a config file: %s" % (original_config_file))
|
logger.warning("Unable to select a config file: %s" % (original_config_file))
|
||||||
return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
|
return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
|
||||||
|
|
||||||
print(f"Trying to load: {original_config_file}")
|
logger.debug(f"Trying to load: {original_config_file}")
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
|
|
||||||
num_train_timesteps = original_config.model.params.timesteps
|
num_train_timesteps = original_config.model.params.timesteps
|
||||||
|
@ -1489,7 +1304,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||||
|
|
||||||
|
|
||||||
print("Converting unet...")
|
logger.info("Converting UNet...")
|
||||||
# Convert the UNet2DConditionModel model.
|
# Convert the UNet2DConditionModel model.
|
||||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||||
unet_config["upcast_attention"] = upcast_attention
|
unet_config["upcast_attention"] = upcast_attention
|
||||||
|
@ -1501,14 +1316,16 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
db_config.has_ema = has_ema
|
db_config.has_ema = has_ema
|
||||||
db_config.save()
|
db_config.save()
|
||||||
unet.load_state_dict(converted_unet_checkpoint)
|
unet.load_state_dict(converted_unet_checkpoint)
|
||||||
print("Converting vae...")
|
|
||||||
|
logger.info("Converting VAE...")
|
||||||
# Convert the VAE model.
|
# Convert the VAE model.
|
||||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
print("Converting text encoder...")
|
|
||||||
|
logger.info("Converting text encoder...")
|
||||||
# Convert the text model.
|
# Convert the text model.
|
||||||
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||||
if text_model_type == "FrozenOpenCLIPEmbedder":
|
if text_model_type == "FrozenOpenCLIPEmbedder":
|
||||||
|
@ -1557,17 +1374,17 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
|
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
|
||||||
scheduler=scheduler)
|
scheduler=scheduler)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Exception setting up output: {e}")
|
logger.error(f"Exception setting up output: {e}")
|
||||||
pipe = None
|
pipe = None
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
if pipe is None or db_config is None:
|
if pipe is None or db_config is None:
|
||||||
msg = "Pipeline or config is not set, unable to continue."
|
msg = "Pipeline or config is not set, unable to continue."
|
||||||
print(msg)
|
logger.error(msg)
|
||||||
return "", "", 0, 0, "", "", "", "", image_size, "", msg
|
return "", "", 0, 0, "", "", "", "", image_size, "", msg
|
||||||
else:
|
else:
|
||||||
resolution = db_config.resolution
|
resolution = db_config.resolution
|
||||||
print("Saving diffusion model...")
|
logger.info("Saving diffusion model...")
|
||||||
pipe.save_pretrained(db_config.pretrained_model_name_or_path)
|
pipe.save_pretrained(db_config.pretrained_model_name_or_path)
|
||||||
result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
|
result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
|
||||||
model_dir = db_config.model_dir
|
model_dir = db_config.model_dir
|
||||||
|
@ -1576,12 +1393,12 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
src = db_config.src
|
src = db_config.src
|
||||||
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
|
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
|
||||||
if original_config_file is not None and os.path.exists(original_config_file):
|
if original_config_file is not None and os.path.exists(original_config_file):
|
||||||
logger.warn("copying original config: %s -> %s", original_config_file, db_config.model_dir)
|
logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
|
||||||
shutil.copy(original_config_file, db_config.model_dir)
|
shutil.copy(original_config_file, db_config.model_dir)
|
||||||
basename = os.path.basename(original_config_file)
|
basename = os.path.basename(original_config_file)
|
||||||
new_ex_path = os.path.join(db_config.model_dir, basename)
|
new_ex_path = os.path.join(db_config.model_dir, basename)
|
||||||
new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
|
new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
|
||||||
logger.warn("copying model config to new name: %s -> %s", new_ex_path, new_name)
|
logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
|
||||||
if os.path.exists(new_name):
|
if os.path.exists(new_name):
|
||||||
os.remove(new_name)
|
os.remove(new_name)
|
||||||
os.rename(new_ex_path, new_name)
|
os.rename(new_ex_path, new_name)
|
||||||
|
@ -1601,27 +1418,35 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
||||||
os.makedirs(rem_dir)
|
os.makedirs(rem_dir)
|
||||||
|
|
||||||
|
|
||||||
print(result_status)
|
logger.info(result_status)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def convert_diffusion_original(ctx: ConversionContext, model_name: str, tensor_file: str, opset: int, half: bool):
|
def convert_diffusion_original(
|
||||||
model_path = os.path.join(ctx.model_path, model_name)
|
ctx: ConversionContext,
|
||||||
torch_name = model_name.replace("onnx", "torch")
|
model: ModelDict,
|
||||||
torch_path = os.path.join(ctx.model_path, torch_name)
|
source: str,
|
||||||
working_name = os.path.join(ctx.model_path, torch_name, "working")
|
):
|
||||||
logger.info("Converting original Diffusers checkpoint %s: %s -> %s", model_name, tensor_file, model_path)
|
name = model["name"]
|
||||||
|
source = source or model["source"]
|
||||||
|
|
||||||
if os.path.exists(model_path):
|
dest = os.path.join(ctx.model_path, name)
|
||||||
|
logger.info("Converting original Diffusers checkpoint %s: %s -> %s", name, source, dest)
|
||||||
|
|
||||||
|
if os.path.exists(dest):
|
||||||
logger.info("ONNX pipeline already exists, skipping.")
|
logger.info("ONNX pipeline already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
torch_name = name + "-torch"
|
||||||
|
torch_path = os.path.join(ctx.cache_path, torch_name)
|
||||||
|
working_name = os.path.join(ctx.cache_path, torch_name, "working")
|
||||||
|
|
||||||
if os.path.exists(torch_path):
|
if os.path.exists(torch_path):
|
||||||
logger.info("Torch pipeline already exists, reusing.")
|
logger.info("Torch pipeline already exists, reusing.")
|
||||||
else:
|
else:
|
||||||
logger.info("Converting original Diffusers check to Torch model: %s -> %s", tensor_file, torch_path)
|
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
||||||
extract_checkpoint(ctx, torch_name, tensor_file, from_hub=False)
|
extract_checkpoint(ctx, torch_name, source, from_hub=False)
|
||||||
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
logger.info("Converted original Diffusers checkpoint to Torch model.")
|
||||||
|
|
||||||
convert_diffusion_stable(ctx, model_path, working_name, opset, half, None)
|
convert_diffusion_stable(ctx, model, working_name)
|
||||||
logger.info("ONNX pipeline saved to %s", model_name)
|
logger.info("ONNX pipeline saved to %s", name)
|
||||||
|
|
|
@ -1,19 +1,22 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from os import mkdir, path
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import rmtree
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
OnnxRuntimeModel,
|
OnnxRuntimeModel,
|
||||||
OnnxStableDiffusionPipeline,
|
OnnxStableDiffusionPipeline,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
from torch.onnx import export
|
|
||||||
from logging import getLogger
|
|
||||||
|
|
||||||
from shutil import rmtree
|
|
||||||
import torch
|
|
||||||
from os import path, mkdir
|
|
||||||
from pathlib import Path
|
|
||||||
from onnx import load, save_model
|
from onnx import load, save_model
|
||||||
|
from torch.onnx import export
|
||||||
|
|
||||||
|
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
|
||||||
|
OnnxStableDiffusionUpscalePipeline,
|
||||||
|
)
|
||||||
from .utils import ConversionContext
|
from .utils import ConversionContext
|
||||||
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -48,21 +51,21 @@ def onnx_export(
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_stable(
|
def convert_diffusion_stable(
|
||||||
ctx: ConversionContext,
|
ctx: ConversionContext,
|
||||||
name: str,
|
model: Dict,
|
||||||
url: str,
|
source: str,
|
||||||
opset: int,
|
|
||||||
half: bool,
|
|
||||||
token: str,
|
|
||||||
single_vae: bool = False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
dtype = torch.float16 if half else torch.float32
|
name = model.get("name")
|
||||||
|
source = source or model.get("source")
|
||||||
|
single_vae = model.get("single_vae")
|
||||||
|
|
||||||
|
dtype = torch.float16 if ctx.half else torch.float32
|
||||||
dest_path = path.join(ctx.model_path, name)
|
dest_path = path.join(ctx.model_path, name)
|
||||||
|
|
||||||
# diffusers go into a directory rather than .onnx file
|
# diffusers go into a directory rather than .onnx file
|
||||||
logger.info("converting Stable Diffusion model %s: %s -> %s/", name, url, dest_path)
|
logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path)
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
logger.info("converting model with single VAE")
|
logger.info("converting model with single VAE")
|
||||||
|
@ -71,13 +74,16 @@ def convert_diffusion_stable(
|
||||||
logger.info("ONNX model already exists, skipping.")
|
logger.info("ONNX model already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if half and ctx.training_device != "cuda":
|
if ctx.half and ctx.training_device != "cuda":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Half precision model export is only supported on GPUs with CUDA"
|
"Half precision model export is only supported on GPUs with CUDA"
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
url, torch_dtype=dtype, use_auth_token=token
|
source,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
use_auth_token=ctx.token,
|
||||||
|
# cache_dir=path.join(ctx.cache_path, name)
|
||||||
).to(ctx.training_device)
|
).to(ctx.training_device)
|
||||||
output_path = Path(dest_path)
|
output_path = Path(dest_path)
|
||||||
|
|
||||||
|
@ -94,14 +100,16 @@ def convert_diffusion_stable(
|
||||||
onnx_export(
|
onnx_export(
|
||||||
pipeline.text_encoder,
|
pipeline.text_encoder,
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
model_args=(text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)),
|
model_args=(
|
||||||
|
text_input.input_ids.to(device=ctx.training_device, dtype=torch.int32)
|
||||||
|
),
|
||||||
output_path=output_path / "text_encoder" / "model.onnx",
|
output_path=output_path / "text_encoder" / "model.onnx",
|
||||||
ordered_input_names=["input_ids"],
|
ordered_input_names=["input_ids"],
|
||||||
output_names=["last_hidden_state", "pooler_output"],
|
output_names=["last_hidden_state", "pooler_output"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"input_ids": {0: "batch", 1: "sequence"},
|
"input_ids": {0: "batch", 1: "sequence"},
|
||||||
},
|
},
|
||||||
opset=opset,
|
opset=ctx.opset,
|
||||||
)
|
)
|
||||||
del pipeline.text_encoder
|
del pipeline.text_encoder
|
||||||
|
|
||||||
|
@ -113,7 +121,9 @@ def convert_diffusion_stable(
|
||||||
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
|
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
|
||||||
else:
|
else:
|
||||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
||||||
unet_scale = torch.tensor(False).to(device=ctx.training_device, dtype=torch.bool)
|
unet_scale = torch.tensor(False).to(
|
||||||
|
device=ctx.training_device, dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
unet_in_channels = pipeline.unet.config.in_channels
|
unet_in_channels = pipeline.unet.config.in_channels
|
||||||
unet_sample_size = pipeline.unet.config.sample_size
|
unet_sample_size = pipeline.unet.config.sample_size
|
||||||
|
@ -139,7 +149,7 @@ def convert_diffusion_stable(
|
||||||
"timestep": {0: "batch"},
|
"timestep": {0: "batch"},
|
||||||
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
||||||
},
|
},
|
||||||
opset=opset,
|
opset=ctx.opset,
|
||||||
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
|
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
|
||||||
)
|
)
|
||||||
unet_model_path = str(unet_path.absolute().as_posix())
|
unet_model_path = str(unet_path.absolute().as_posix())
|
||||||
|
@ -182,7 +192,7 @@ def convert_diffusion_stable(
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
},
|
},
|
||||||
opset=opset,
|
opset=ctx.opset,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# VAE ENCODER
|
# VAE ENCODER
|
||||||
|
@ -207,7 +217,7 @@ def convert_diffusion_stable(
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
},
|
},
|
||||||
opset=opset,
|
opset=ctx.opset,
|
||||||
)
|
)
|
||||||
|
|
||||||
# VAE DECODER
|
# VAE DECODER
|
||||||
|
@ -230,7 +240,7 @@ def convert_diffusion_stable(
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
},
|
},
|
||||||
opset=opset,
|
opset=ctx.opset,
|
||||||
)
|
)
|
||||||
|
|
||||||
del pipeline.vae
|
del pipeline.vae
|
||||||
|
@ -261,7 +271,7 @@ def convert_diffusion_stable(
|
||||||
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
|
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
|
||||||
},
|
},
|
||||||
opset=opset,
|
opset=ctx.opset,
|
||||||
)
|
)
|
||||||
del pipeline.safety_checker
|
del pipeline.safety_checker
|
||||||
safety_checker = OnnxRuntimeModel.from_pretrained(
|
safety_checker = OnnxRuntimeModel.from_pretrained(
|
||||||
|
@ -312,4 +322,3 @@ def convert_diffusion_stable(
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("ONNX pipeline is loadable")
|
logger.info("ONNX pipeline is loadable")
|
||||||
|
|
||||||
|
|
|
@ -1,32 +1,34 @@
|
||||||
import torch
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
from os import path
|
|
||||||
from logging import getLogger
|
from .utils import ConversionContext, ModelDict
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from .utils import ConversionContext
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale: int, opset: int):
|
def convert_upscale_resrgan(
|
||||||
dest_path = path.join(ctx.model_path, name + ".pth")
|
ctx: ConversionContext,
|
||||||
dest_onnx = path.join(ctx.model_path, name + ".onnx")
|
model: ModelDict,
|
||||||
logger.info("converting Real ESRGAN model: %s -> %s", name, dest_onnx)
|
source: str,
|
||||||
|
):
|
||||||
|
name = model.get("name")
|
||||||
|
source = source or model.get("source")
|
||||||
|
scale = model.get("scale")
|
||||||
|
|
||||||
if path.isfile(dest_onnx):
|
dest = path.join(ctx.model_path, name + ".onnx")
|
||||||
|
logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
|
||||||
|
|
||||||
|
if path.isfile(dest):
|
||||||
logger.info("ONNX model already exists, skipping.")
|
logger.info("ONNX model already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not path.isfile(dest_path):
|
|
||||||
logger.info("PTH model not found, downloading...")
|
|
||||||
download_path = load_file_from_url(
|
|
||||||
url=url, model_dir=dest_path + "-cache", progress=True, file_name=None
|
|
||||||
)
|
|
||||||
copyfile(download_path, dest_path)
|
|
||||||
|
|
||||||
logger.info("loading and training model")
|
logger.info("loading and training model")
|
||||||
model = RRDBNet(
|
model = RRDBNet(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
|
@ -37,7 +39,7 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale:
|
||||||
scale=scale,
|
scale=scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_model = torch.load(dest_path, map_location=ctx.map_location)
|
torch_model = torch.load(source, map_location=ctx.map_location)
|
||||||
if "params_ema" in torch_model:
|
if "params_ema" in torch_model:
|
||||||
model.load_state_dict(torch_model["params_ema"])
|
model.load_state_dict(torch_model["params_ema"])
|
||||||
else:
|
else:
|
||||||
|
@ -54,15 +56,15 @@ def convert_upscale_resrgan(ctx: ConversionContext, name: str, url: str, scale:
|
||||||
"output": {2: "width", 3: "height"},
|
"output": {2: "width", 3: "height"},
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("exporting ONNX model to %s", dest_onnx)
|
logger.info("exporting ONNX model to %s", dest)
|
||||||
export(
|
export(
|
||||||
model,
|
model,
|
||||||
rng,
|
rng,
|
||||||
dest_onnx,
|
dest,
|
||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
dynamic_axes=dynamic_axes,
|
dynamic_axes=dynamic_axes,
|
||||||
opset_version=opset,
|
opset_version=ctx.opset,
|
||||||
export_params=True,
|
export_params=True,
|
||||||
)
|
)
|
||||||
logger.info("Real ESRGAN exported to ONNX successfully.")
|
logger.info("Real ESRGAN exported to ONNX successfully.")
|
||||||
|
|
|
@ -1,7 +1,129 @@
|
||||||
|
import shutil
|
||||||
|
from functools import partial
|
||||||
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Union, List, Optional, Tuple
|
||||||
|
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ModelDict = Dict[str, Union[str, int]]
|
||||||
|
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
||||||
|
|
||||||
|
|
||||||
class ConversionContext:
|
class ConversionContext:
|
||||||
def __init__(self, model_path: str, device: str) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
device: str,
|
||||||
|
cache_path: Optional[str] = None,
|
||||||
|
half: Optional[bool] = False,
|
||||||
|
opset: Optional[int] = None,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
self.cache_path = cache_path or path.join(model_path, ".cache")
|
||||||
self.training_device = device
|
self.training_device = device
|
||||||
self.map_location = torch.device(device)
|
self.map_location = torch.device(device)
|
||||||
|
self.half = half
|
||||||
|
self.opset = opset
|
||||||
|
self.token = token
|
||||||
|
|
||||||
|
|
||||||
|
def download_progress(urls: List[Tuple[str, str]]):
|
||||||
|
for url, dest in urls:
|
||||||
|
dest_path = Path(dest).expanduser().resolve()
|
||||||
|
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if dest_path.exists():
|
||||||
|
logger.info("Destination already exists: %s", dest_path)
|
||||||
|
return str(dest_path.absolute())
|
||||||
|
|
||||||
|
req = requests.get(url, stream=True, allow_redirects=True)
|
||||||
|
if req.status_code != 200:
|
||||||
|
req.raise_for_status() # Only works for 4xx errors, per SO answer
|
||||||
|
raise RuntimeError(
|
||||||
|
"Request to %s failed with status code: %s" % (url, req.status_code)
|
||||||
|
)
|
||||||
|
|
||||||
|
total = int(req.headers.get("Content-Length", 0))
|
||||||
|
desc = "unknown" if total == 0 else ""
|
||||||
|
req.raw.read = partial(req.raw.read, decode_content=True)
|
||||||
|
with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
|
||||||
|
with dest_path.open("wb") as f:
|
||||||
|
shutil.copyfileobj(data, f)
|
||||||
|
|
||||||
|
return str(dest_path.absolute())
|
||||||
|
|
||||||
|
|
||||||
|
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
|
||||||
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
|
name, source, *rest = model
|
||||||
|
scale = rest[0] if len(rest) > 0 else 1
|
||||||
|
half = rest[0] if len(rest) > 0 else False
|
||||||
|
opset = rest[0] if len(rest) > 0 else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"source": source,
|
||||||
|
"half": half,
|
||||||
|
"opset": opset,
|
||||||
|
"scale": scale,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
|
||||||
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
|
name, source, *rest = model
|
||||||
|
single_vae = rest[0] if len(rest) > 0 else False
|
||||||
|
half = rest[0] if len(rest) > 0 else False
|
||||||
|
opset = rest[0] if len(rest) > 0 else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"source": source,
|
||||||
|
"half": half,
|
||||||
|
"opset": opset,
|
||||||
|
"single_vae": single_vae,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
||||||
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
|
name, source, *rest = model
|
||||||
|
scale = rest[0] if len(rest) > 0 else 1
|
||||||
|
half = rest[0] if len(rest) > 0 else False
|
||||||
|
opset = rest[0] if len(rest) > 0 else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"source": source,
|
||||||
|
"half": half,
|
||||||
|
"opset": opset,
|
||||||
|
"scale": scale,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
known_formats = ["onnx", "pth", "ckpt", "safetensors"]
|
||||||
|
|
||||||
|
def source_format(model: Dict) -> Optional[str]:
|
||||||
|
if "format" in model:
|
||||||
|
return model["format"]
|
||||||
|
|
||||||
|
if "source" in model:
|
||||||
|
ext = path.splitext(model["source"])
|
||||||
|
if ext in known_formats:
|
||||||
|
return ext
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
|
@ -649,10 +649,10 @@ def chain():
|
||||||
return error_reply("chain pipeline must have a body")
|
return error_reply("chain pipeline must have a body")
|
||||||
|
|
||||||
data = yaml.safe_load(body)
|
data = yaml.safe_load(body)
|
||||||
with open("./schema.yaml", "r") as f:
|
with open("./schemas/chain.yaml", "r") as f:
|
||||||
schema = yaml.safe_load(f.read())
|
schema = yaml.safe_load(f.read())
|
||||||
|
|
||||||
logger.info("validating chain request: %s against %s", data, schema)
|
logger.debug("validating chain request: %s against %s", data, schema)
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
# get defaults from the regular parameters
|
# get defaults from the regular parameters
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
[tool.black]
|
[tool.black]
|
||||||
force-exclude = '''/(lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py'''
|
force-exclude = '''/(diffusion_original|lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_upscale).py'''
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml
|
$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/chain.yaml
|
||||||
$schema: https://json-schema.org/draft/2020-12/schema
|
$schema: https://json-schema.org/draft/2020-12/schema
|
||||||
|
|
||||||
$defs:
|
$defs:
|
|
@ -0,0 +1,65 @@
|
||||||
|
$id: https://github.com/ssube/onnx-web/blob/main/api/schemas/extras.yaml
|
||||||
|
$schema: https://json-schema.org/draft/2020-12/schema
|
||||||
|
|
||||||
|
$defs:
|
||||||
|
legacy_tuple:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
oneOf:
|
||||||
|
- type: string
|
||||||
|
- type: number
|
||||||
|
|
||||||
|
base_model:
|
||||||
|
type: object
|
||||||
|
required: [name, source]
|
||||||
|
properties:
|
||||||
|
format:
|
||||||
|
type: string
|
||||||
|
enum: [onnx, pth, ckpt, safetensors]
|
||||||
|
half:
|
||||||
|
type: boolean
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
opset:
|
||||||
|
type: number
|
||||||
|
source:
|
||||||
|
type: string
|
||||||
|
|
||||||
|
correction_model:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/$defs/base_model"
|
||||||
|
|
||||||
|
diffusion_model:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/$defs/base_model"
|
||||||
|
|
||||||
|
upscaling_model:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/$defs/base_model"
|
||||||
|
- type: object
|
||||||
|
required: [scale]
|
||||||
|
properties:
|
||||||
|
scale:
|
||||||
|
type: number
|
||||||
|
|
||||||
|
type: object
|
||||||
|
additionalProperties: False
|
||||||
|
properties:
|
||||||
|
diffusion:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
oneOf:
|
||||||
|
- $ref: "#/$defs/legacy_tuple"
|
||||||
|
- $ref: "#/$defs/diffusion_model"
|
||||||
|
correction:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
oneOf:
|
||||||
|
- $ref: "#/$defs/legacy_tuple"
|
||||||
|
- $ref: "#/$defs/correction_model"
|
||||||
|
upscaling:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
oneOf:
|
||||||
|
- $ref: "#/$defs/legacy_tuple"
|
||||||
|
- $ref: "#/$defs/upscaling_model"
|
|
@ -58,6 +58,7 @@
|
||||||
"rocm",
|
"rocm",
|
||||||
"RRDB",
|
"RRDB",
|
||||||
"runwayml",
|
"runwayml",
|
||||||
|
"safetensors",
|
||||||
"scandir",
|
"scandir",
|
||||||
"scipy",
|
"scipy",
|
||||||
"scrollback",
|
"scrollback",
|
||||||
|
|
Loading…
Reference in New Issue