lint(api): apply to original diffusers converter

parent dbee258a36
commit c599385a30
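This commit runs the API linters over the original-diffusers conversion path: imports are regrouped and sorted, long calls and signatures are rewrapped, `not ... in` becomes `not in`, the bare `except:` gains an exception binding, and the old placeholder tuple returns collapse to bare `return`. A minimal sketch of reproducing the formatting portion locally, assuming the project lints with isort (black profile) and black; the exact paths and flags are an assumption, not taken from this diff:

    # hypothetical lint invocation; match it to the project's real lint config
    import subprocess

    subprocess.run(["isort", "--profile", "black", "api/onnx_web/convert/"], check=True)
    subprocess.run(["black", "api/onnx_web/convert/"], check=True)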
@@ -1,20 +1,26 @@
 import warnings
 from argparse import ArgumentParser
-from json import loads
 from logging import getLogger
 from os import environ, makedirs, path
 from sys import exit
 from typing import Dict, List, Optional, Tuple
-from yaml import safe_load
-from jsonschema import validate, ValidationError
 
 import torch
+from jsonschema import ValidationError, validate
+from yaml import safe_load
 
 from .correction_gfpgan import convert_correction_gfpgan
 from .diffusion_original import convert_diffusion_original
 from .diffusion_stable import convert_diffusion_stable
 from .upscale_resrgan import convert_upscale_resrgan
-from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
+from .utils import (
+    ConversionContext,
+    download_progress,
+    source_format,
+    tuple_to_correction,
+    tuple_to_diffusion,
+    tuple_to_upscaling,
+)
 
 # suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
 warnings.filterwarnings(
@@ -100,7 +106,9 @@ model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
 training_device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
-def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str:
+def fetch_model(
+    ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
+) -> str:
     cache_name = path.join(ctx.cache_path, name)
     if format is not None:
         # add an extension if possible, some of the conversion code checks for it
@@ -110,7 +118,9 @@ def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional
         api_name, api_root = model_sources.get(proto)
         if source.startswith(proto):
             api_source = api_root % (source.removeprefix(proto))
-            logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
+            logger.info(
+                "Downloading model from %s: %s -> %s", api_name, api_source, cache_name
+            )
             return download_progress([(api_source, cache_name)])
 
     if source.startswith(model_source_huggingface):
@@ -218,7 +228,9 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)
 
-    ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
+    ctx = ConversionContext(
+        model_path, training_device, half=args.half, opset=args.opset, token=args.token
+    )
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
 
     if not path.exists(model_path):
@@ -1,10 +1,8 @@
 from logging import getLogger
 from os import path
-from shutil import copyfile
 
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils.download_util import load_file_from_url
 from torch.onnx import export
 
 from .utils import ConversionContext, ModelDict
@@ -959,6 +959,7 @@ def replace_symlinks(path, base):
         for subpath in os.listdir(path):
             replace_symlinks(os.path.join(path, subpath), base)
 
+
 def download_model(db_config: DreamboothConfig, token):
     tmp_dir = os.path.join(db_config.model_dir, "src")
 
@@ -994,7 +995,7 @@ def download_model(db_config: DreamboothConfig, token):
         if "model_index.json" in name:
             model_index = name
             continue
-        if (".ckpt" in name or ".safetensors" in name) and not "/" in name:
+        if (".ckpt" in name or ".safetensors" in name) and "/" not in name:
            model_files.append(name)
            continue
        for diffusion_dir in diffusion_dirs:
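The one-line change here is pycodestyle E713: spell a negated membership test with the `not in` operator rather than negating the whole expression. Both forms are equivalent:

    # E713: "not x in y" and "x not in y" parse to the same test
    name = "v1-5-pruned.ckpt"
    assert (not "/" in name) == ("/" not in name)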
@@ -1034,8 +1035,6 @@ def download_model(db_config: DreamboothConfig, token):
         logger.debug("Nothing to fetch!")
         return None, None
 
-
-    # huggingface_hub.utils.tqdm.tqdm = mytqdm
     mytqdm = huggingface_hub.utils.tqdm.tqdm
     out_model = None
     for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
@@ -1058,7 +1057,7 @@ def download_model(db_config: DreamboothConfig, token):
         for diffusion_dir in diffusion_dirs:
             if diffusion_dir in out:
                 out_model = db_config.pretrained_model_name_or_path
-                dest = os.path.join(db_config.pretrained_model_name_or_path,diffusion_dir)
+                dest = os.path.join(db_config.pretrained_model_name_or_path, diffusion_dir)
         if not dest:
             if ".ckpt" in out or ".safetensors" in out:
                 dest = os.path.join(db_config.model_dir, "src")
@@ -1074,12 +1073,13 @@ def download_model(db_config: DreamboothConfig, token):
 
     return out_model, config_file
 
+
 def get_config_path(
     model_version: str = "v1",
     train_type: str = "default",
     config_base_name: str = "training",
     prediction_type: str = "epsilon"
 ):
     train_type = f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v"
 
     parts = os.path.join(
@@ -1093,6 +1093,7 @@ def get_config_path(
     )
     return os.path.abspath(parts)
 
+
 def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
 
     config_base_name = "training"
@@ -1117,8 +1118,18 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
     return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
 
 
-def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
-                       new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
+def extract_checkpoint(
+    ctx: ConversionContext,
+    new_model_name: str,
+    checkpoint_file: str,
+    scheduler_type="ddim",
+    from_hub=False,
+    new_model_url="",
+    new_model_token="",
+    extract_ema=False,
+    train_unfrozen=False,
+    is_512=True,
+):
     """
 
     @param new_model_name: The name of the new model
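Past the line-length limit, black explodes a signature to one parameter per line; the trailing comma after `is_512=True` is the "magic trailing comma" that keeps the list exploded on future runs. The same shape on a toy function:

    # toy example of the exploded-signature style
    def example(
        name: str,
        size: int = 512,
        half: bool = False,
    ) -> str:
        return f"{name}-{size}-{half}"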
@@ -1153,9 +1164,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
     msg = None
 
     if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
-        msg = "Please provide a URL and token for huggingface models."
-    if msg is not None:
-        return "", "", 0, 0, "", "", "", "", image_size, "", msg
+        logger.warning("Please provide a URL and token for huggingface models.")
+        return
 
     # Create empty config
     db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
@@ -1178,7 +1188,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         else:
             msg = "Unable to fetch model from hub."
             logger.warning(msg)
-            return "", "", 0, 0, "", "", "", "", image_size, "", msg
+            return
 
     try:
         checkpoint = None
@@ -1194,6 +1204,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
                 logger.debug("Loading safetensors...")
                 checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
             except Exception as e:
+                logger.warn("Failed to load as safetensors file, falling back to torch...", e)
                 checkpoint = torch.jit.load(checkpoint_file)
         else:
             logger.debug("Loading ckpt...")
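One caution on the added line: `Logger.warn` is a deprecated alias for `Logger.warning`, and the exception is passed without a matching placeholder. The supported spelling would be:

    # equivalent call on the supported API, with a placeholder for the exception
    logger.warning("Failed to load as safetensors file, falling back to torch: %s", e)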
@@ -1231,10 +1242,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
                 if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
                     logger.debug("UNet using v2 parameters.")
                     v2 = True
-    except:
-        logger.error("Exception loading unet!")
-        traceback.print_exc()
+    except Exception as e:
+        logger.error("Exception loading unet!", traceback.format_exception(e))
 
-
     if v2 and not is_512:
         prediction_type = "v_prediction"
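Binding the exception satisfies flake8 E722 (no bare `except:`) and lets the handler log the formatted traceback instead of printing it to stderr; note that `traceback.format_exception(e)` with a single argument needs Python 3.10+. A self-contained sketch of the new pattern:

    import logging
    import traceback

    logging.basicConfig()
    logger = logging.getLogger(__name__)

    try:
        raise KeyError("missing unet key")  # stand-in for the UNet shape probe
    except Exception as e:  # E722: never use a bare "except:"
        # single-argument form is Python 3.10+; it returns a list of lines
        logger.error("Exception loading unet! %s", "".join(traceback.format_exception(e)))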
@@ -1263,7 +1272,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
 
     if original_config_file is None or not os.path.exists(original_config_file):
         logger.warning("Unable to select a config file: %s" % (original_config_file))
-        return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
+        return
 
     logger.debug(f"Trying to load: {original_config_file}")
     original_config = OmegaConf.load(original_config_file)
@@ -1303,9 +1312,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
     else:
         raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
 
-
-    logger.info("Converting UNet...")
     # Convert the UNet2DConditionModel model.
+    logger.info("Converting UNet...")
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
     unet = UNet2DConditionModel(**unet_config)
@@ -1317,16 +1325,16 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
     db_config.save()
     unet.load_state_dict(converted_unet_checkpoint)
 
-    logger.info("Converting VAE...")
     # Convert the VAE model.
+    logger.info("Converting VAE...")
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
 
-    logger.info("Converting text encoder...")
     # Convert the text model.
+    logger.info("Converting text encoder...")
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
     if text_model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
@@ -1374,23 +1382,19 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
                                           scheduler=scheduler)
     except Exception as e:
-        logger.error(f"Exception setting up output: {e}")
+        logger.error("Exception setting up output: %s", traceback.format_exception(e))
         pipe = None
-        traceback.print_exc()
 
     if pipe is None or db_config is None:
         msg = "Pipeline or config is not set, unable to continue."
         logger.error(msg)
-        return "", "", 0, 0, "", "", "", "", image_size, "", msg
+        return
     else:
-        resolution = db_config.resolution
         logger.info("Saving diffusion model...")
         pipe.save_pretrained(db_config.pretrained_model_name_or_path)
         result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
-        model_dir = db_config.model_dir
         revision = db_config.revision
         scheduler = db_config.scheduler
-        src = db_config.src
         required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
         if original_config_file is not None and os.path.exists(original_config_file):
             logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
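The f-string in the exception handler becomes %-style arguments handed to the logger, so formatting is deferred until the record is actually emitted (and skipped entirely if the level is filtered). A short sketch of the difference:

    import logging

    logger = logging.getLogger(__name__)
    err = ValueError("boom")

    logger.error(f"Exception setting up output: {err}")   # interpolates eagerly
    logger.error("Exception setting up output: %s", err)  # lets logging format lazily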
@@ -1417,11 +1421,11 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         if not os.path.exists(rem_dir):
             os.makedirs(rem_dir)
 
-
     logger.info(result_status)
 
     return
 
+
 def convert_diffusion_original(
     ctx: ConversionContext,
     model: ModelDict,
@@ -65,7 +65,9 @@ def convert_diffusion_stable(
     dest_path = path.join(ctx.model_path, name)
 
     # diffusers go into a directory rather than .onnx file
-    logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path)
+    logger.info(
+        "converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
+    )
 
     if single_vae:
         logger.info("converting model with single VAE")
@@ -1,10 +1,8 @@
 from logging import getLogger
 from os import path
-from shutil import copyfile
 
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils.download_util import load_file_from_url
 from torch.onnx import export
 
 from .utils import ConversionContext, ModelDict
@@ -3,7 +3,7 @@ from functools import partial
 from logging import getLogger
 from os import path
 from pathlib import Path
-from typing import Dict, Union, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import requests
 import torch
@@ -44,7 +44,14 @@ def download_progress(urls: List[Tuple[str, str]]):
         logger.info("Destination already exists: %s", dest_path)
         return str(dest_path.absolute())
 
-    req = requests.get(url, stream=True, allow_redirects=True)
+    req = requests.get(
+        url,
+        stream=True,
+        allow_redirects=True,
+        headers={
+            "User-Agent": "onnx-web-api",
+        },
+    )
     if req.status_code != 200:
         req.raise_for_status()  # Only works for 4xx errors, per SO answer
         raise RuntimeError(
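Beyond the rewrap, this hunk also adds an explicit `User-Agent` header, presumably so download hosts can identify the client rather than rejecting an anonymous request. The standalone shape of the call:

    import requests

    req = requests.get(
        "https://example.com/model.ckpt",  # placeholder URL
        stream=True,
        allow_redirects=True,
        headers={"User-Agent": "onnx-web-api"},
    )
    req.raise_for_status()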
@@ -117,6 +124,7 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
 
 known_formats = ["onnx", "pth", "ckpt", "safetensors"]
 
+
 def source_format(model: Dict) -> Optional[str]:
     if "format" in model:
         return model["format"]