lint(api): apply to original diffusers converter
This commit is contained in:
parent
dbee258a36
commit
c599385a30
|
@ -1,20 +1,26 @@
|
|||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
from json import loads
|
||||
from logging import getLogger
|
||||
from os import environ, makedirs, path
|
||||
from sys import exit
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from yaml import safe_load
|
||||
from jsonschema import validate, ValidationError
|
||||
|
||||
import torch
|
||||
from jsonschema import ValidationError, validate
|
||||
from yaml import safe_load
|
||||
|
||||
from .correction_gfpgan import convert_correction_gfpgan
|
||||
from .diffusion_original import convert_diffusion_original
|
||||
from .diffusion_stable import convert_diffusion_stable
|
||||
from .upscale_resrgan import convert_upscale_resrgan
|
||||
from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
|
||||
from .utils import (
|
||||
ConversionContext,
|
||||
download_progress,
|
||||
source_format,
|
||||
tuple_to_correction,
|
||||
tuple_to_diffusion,
|
||||
tuple_to_upscaling,
|
||||
)
|
||||
|
||||
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
|
||||
warnings.filterwarnings(
|
||||
|
@ -100,7 +106,9 @@ model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
|||
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str:
|
||||
def fetch_model(
|
||||
ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
|
||||
) -> str:
|
||||
cache_name = path.join(ctx.cache_path, name)
|
||||
if format is not None:
|
||||
# add an extension if possible, some of the conversion code checks for it
|
||||
|
@ -110,7 +118,9 @@ def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional
|
|||
api_name, api_root = model_sources.get(proto)
|
||||
if source.startswith(proto):
|
||||
api_source = api_root % (source.removeprefix(proto))
|
||||
logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
|
||||
logger.info(
|
||||
"Downloading model from %s: %s -> %s", api_name, api_source, cache_name
|
||||
)
|
||||
return download_progress([(api_source, cache_name)])
|
||||
|
||||
if source.startswith(model_source_huggingface):
|
||||
|
@ -218,7 +228,9 @@ def main() -> int:
|
|||
args = parser.parse_args()
|
||||
logger.info("CLI arguments: %s", args)
|
||||
|
||||
ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
|
||||
ctx = ConversionContext(
|
||||
model_path, training_device, half=args.half, opset=args.opset, token=args.token
|
||||
)
|
||||
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
||||
|
||||
if not path.exists(model_path):
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from shutil import copyfile
|
||||
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.onnx import export
|
||||
|
||||
from .utils import ConversionContext, ModelDict
|
||||
|
|
|
@ -959,6 +959,7 @@ def replace_symlinks(path, base):
|
|||
for subpath in os.listdir(path):
|
||||
replace_symlinks(os.path.join(path, subpath), base)
|
||||
|
||||
|
||||
def download_model(db_config: DreamboothConfig, token):
|
||||
tmp_dir = os.path.join(db_config.model_dir, "src")
|
||||
|
||||
|
@ -994,7 +995,7 @@ def download_model(db_config: DreamboothConfig, token):
|
|||
if "model_index.json" in name:
|
||||
model_index = name
|
||||
continue
|
||||
if (".ckpt" in name or ".safetensors" in name) and not "/" in name:
|
||||
if (".ckpt" in name or ".safetensors" in name) and "/" not in name:
|
||||
model_files.append(name)
|
||||
continue
|
||||
for diffusion_dir in diffusion_dirs:
|
||||
|
@ -1034,8 +1035,6 @@ def download_model(db_config: DreamboothConfig, token):
|
|||
logger.debug("Nothing to fetch!")
|
||||
return None, None
|
||||
|
||||
|
||||
# huggingface_hub.utils.tqdm.tqdm = mytqdm
|
||||
mytqdm = huggingface_hub.utils.tqdm.tqdm
|
||||
out_model = None
|
||||
for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
|
||||
|
@ -1074,6 +1073,7 @@ def download_model(db_config: DreamboothConfig, token):
|
|||
|
||||
return out_model, config_file
|
||||
|
||||
|
||||
def get_config_path(
|
||||
model_version: str = "v1",
|
||||
train_type: str = "default",
|
||||
|
@ -1093,6 +1093,7 @@ def get_config_path(
|
|||
)
|
||||
return os.path.abspath(parts)
|
||||
|
||||
|
||||
def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
|
||||
|
||||
config_base_name = "training"
|
||||
|
@ -1117,8 +1118,18 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
|
|||
return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
|
||||
|
||||
|
||||
def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
|
||||
new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
|
||||
def extract_checkpoint(
|
||||
ctx: ConversionContext,
|
||||
new_model_name: str,
|
||||
checkpoint_file: str,
|
||||
scheduler_type="ddim",
|
||||
from_hub=False,
|
||||
new_model_url="",
|
||||
new_model_token="",
|
||||
extract_ema=False,
|
||||
train_unfrozen=False,
|
||||
is_512=True,
|
||||
):
|
||||
"""
|
||||
|
||||
@param new_model_name: The name of the new model
|
||||
|
@ -1153,9 +1164,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
msg = None
|
||||
|
||||
if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
|
||||
msg = "Please provide a URL and token for huggingface models."
|
||||
if msg is not None:
|
||||
return "", "", 0, 0, "", "", "", "", image_size, "", msg
|
||||
logger.warning("Please provide a URL and token for huggingface models.")
|
||||
return
|
||||
|
||||
# Create empty config
|
||||
db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
|
||||
|
@ -1178,7 +1188,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
else:
|
||||
msg = "Unable to fetch model from hub."
|
||||
logger.warning(msg)
|
||||
return "", "", 0, 0, "", "", "", "", image_size, "", msg
|
||||
return
|
||||
|
||||
try:
|
||||
checkpoint = None
|
||||
|
@ -1194,6 +1204,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
logger.debug("Loading safetensors...")
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
except Exception as e:
|
||||
logger.warn("Failed to load as safetensors file, falling back to torch...", e)
|
||||
checkpoint = torch.jit.load(checkpoint_file)
|
||||
else:
|
||||
logger.debug("Loading ckpt...")
|
||||
|
@ -1231,10 +1242,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
|
||||
logger.debug("UNet using v2 parameters.")
|
||||
v2 = True
|
||||
|
||||
except:
|
||||
logger.error("Exception loading unet!")
|
||||
traceback.print_exc()
|
||||
except Exception as e:
|
||||
logger.error("Exception loading unet!", traceback.format_exception(e))
|
||||
|
||||
if v2 and not is_512:
|
||||
prediction_type = "v_prediction"
|
||||
|
@ -1263,7 +1272,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
|
||||
if original_config_file is None or not os.path.exists(original_config_file):
|
||||
logger.warning("Unable to select a config file: %s" % (original_config_file))
|
||||
return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
|
||||
return
|
||||
|
||||
logger.debug(f"Trying to load: {original_config_file}")
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
@ -1303,9 +1312,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
else:
|
||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
|
||||
|
||||
logger.info("Converting UNet...")
|
||||
# Convert the UNet2DConditionModel model.
|
||||
logger.info("Converting UNet...")
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
@ -1317,16 +1325,16 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
db_config.save()
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
|
||||
logger.info("Converting VAE...")
|
||||
# Convert the VAE model.
|
||||
logger.info("Converting VAE...")
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
logger.info("Converting text encoder...")
|
||||
# Convert the text model.
|
||||
logger.info("Converting text encoder...")
|
||||
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
if text_model_type == "FrozenOpenCLIPEmbedder":
|
||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||
|
@ -1374,23 +1382,19 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
|
||||
scheduler=scheduler)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception setting up output: {e}")
|
||||
logger.error("Exception setting up output: %s", traceback.format_exception(e))
|
||||
pipe = None
|
||||
traceback.print_exc()
|
||||
|
||||
if pipe is None or db_config is None:
|
||||
msg = "Pipeline or config is not set, unable to continue."
|
||||
logger.error(msg)
|
||||
return "", "", 0, 0, "", "", "", "", image_size, "", msg
|
||||
return
|
||||
else:
|
||||
resolution = db_config.resolution
|
||||
logger.info("Saving diffusion model...")
|
||||
pipe.save_pretrained(db_config.pretrained_model_name_or_path)
|
||||
result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
|
||||
model_dir = db_config.model_dir
|
||||
revision = db_config.revision
|
||||
scheduler = db_config.scheduler
|
||||
src = db_config.src
|
||||
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
|
||||
if original_config_file is not None and os.path.exists(original_config_file):
|
||||
logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
|
||||
|
@ -1417,11 +1421,11 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
|
|||
if not os.path.exists(rem_dir):
|
||||
os.makedirs(rem_dir)
|
||||
|
||||
|
||||
logger.info(result_status)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def convert_diffusion_original(
|
||||
ctx: ConversionContext,
|
||||
model: ModelDict,
|
||||
|
|
|
@ -65,7 +65,9 @@ def convert_diffusion_stable(
|
|||
dest_path = path.join(ctx.model_path, name)
|
||||
|
||||
# diffusers go into a directory rather than .onnx file
|
||||
logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path)
|
||||
logger.info(
|
||||
"converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
|
||||
)
|
||||
|
||||
if single_vae:
|
||||
logger.info("converting model with single VAE")
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from shutil import copyfile
|
||||
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.onnx import export
|
||||
|
||||
from .utils import ConversionContext, ModelDict
|
||||
|
|
|
@ -3,7 +3,7 @@ from functools import partial
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
@ -44,7 +44,14 @@ def download_progress(urls: List[Tuple[str, str]]):
|
|||
logger.info("Destination already exists: %s", dest_path)
|
||||
return str(dest_path.absolute())
|
||||
|
||||
req = requests.get(url, stream=True, allow_redirects=True)
|
||||
req = requests.get(
|
||||
url,
|
||||
stream=True,
|
||||
allow_redirects=True,
|
||||
headers={
|
||||
"User-Agent": "onnx-web-api",
|
||||
},
|
||||
)
|
||||
if req.status_code != 200:
|
||||
req.raise_for_status() # Only works for 4xx errors, per SO answer
|
||||
raise RuntimeError(
|
||||
|
@ -117,6 +124,7 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
|||
|
||||
known_formats = ["onnx", "pth", "ckpt", "safetensors"]
|
||||
|
||||
|
||||
def source_format(model: Dict) -> Optional[str]:
|
||||
if "format" in model:
|
||||
return model["format"]
|
||||
|
|
Loading…
Reference in New Issue