
lint(api): apply to original diffusers converter

Sean Sube 2023-02-10 23:32:16 -06:00
parent dbee258a36
commit c599385a30
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 67 additions and 45 deletions
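
The hunks below are mostly mechanical: imports are sorted and grouped, long calls and signatures are wrapped one argument per line, and "not x in y" tests become "x not in y". The output style matches black plus isort, though the commit does not name its tools; a hedged sketch of reproducing the pass, assuming those formatters and the api/onnx_web/convert/ module path:

# Hypothetical reproduction of this lint pass. Assumes black and isort
# (inferred from the wrapped calls and sorted imports; not stated in the
# commit) and the api/onnx_web/convert/ module path.
from subprocess import run

run(["isort", "api/onnx_web/convert/"], check=True)  # sort and group imports
run(["black", "api/onnx_web/convert/"], check=True)  # wrap long calls and signatures

A few hunks go beyond formatting: the tuple-of-empty-strings error returns in extract_checkpoint collapse to bare return, and download_progress gains a User-Agent header.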

View File

@@ -1,20 +1,26 @@
 import warnings
 from argparse import ArgumentParser
 from json import loads
 from logging import getLogger
 from os import environ, makedirs, path
 from sys import exit
 from typing import Dict, List, Optional, Tuple
-from yaml import safe_load
-from jsonschema import validate, ValidationError
 import torch
+from jsonschema import ValidationError, validate
+from yaml import safe_load
 from .correction_gfpgan import convert_correction_gfpgan
 from .diffusion_original import convert_diffusion_original
 from .diffusion_stable import convert_diffusion_stable
 from .upscale_resrgan import convert_upscale_resrgan
-from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
+from .utils import (
+    ConversionContext,
+    download_progress,
+    source_format,
+    tuple_to_correction,
+    tuple_to_diffusion,
+    tuple_to_upscaling,
+)
 
 # suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
 warnings.filterwarnings(
@@ -100,7 +106,9 @@ model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
 training_device = "cuda" if torch.cuda.is_available() else "cpu"
 
-def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str:
+def fetch_model(
+    ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
+) -> str:
     cache_name = path.join(ctx.cache_path, name)
     if format is not None:
         # add an extension if possible, some of the conversion code checks for it
@@ -110,7 +118,9 @@ def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional
         api_name, api_root = model_sources.get(proto)
         if source.startswith(proto):
             api_source = api_root % (source.removeprefix(proto))
-            logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
+            logger.info(
+                "Downloading model from %s: %s -> %s", api_name, api_source, cache_name
+            )
             return download_progress([(api_source, cache_name)])
 
     if source.startswith(model_source_huggingface):
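
For context on the hunk above: fetch_model resolves a shorthand protocol prefix through the model_sources table, then expands the rest of the source string into a full download URL. A minimal sketch with a hypothetical table entry (the real entries sit outside this hunk):

# Hypothetical sketch of the proto -> URL expansion; the actual
# model_sources entries are not shown in this diff.
model_sources = {
    "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
}

source = "civitai://15236"
for proto, (api_name, api_root) in model_sources.items():
    if source.startswith(proto):
        # "civitai://15236" -> "https://civitai.com/api/download/models/15236"
        api_source = api_root % (source.removeprefix(proto))

Note that str.removeprefix requires Python 3.9 or newer.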
@@ -218,7 +228,9 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)
-    ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
+    ctx = ConversionContext(
+        model_path, training_device, half=args.half, opset=args.opset, token=args.token
+    )
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
 
     if not path.exists(model_path):

View File

@@ -1,10 +1,8 @@
from logging import getLogger
from os import path
from shutil import copyfile
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from torch.onnx import export
from .utils import ConversionContext, ModelDict

View File

@@ -959,6 +959,7 @@ def replace_symlinks(path, base):
         for subpath in os.listdir(path):
             replace_symlinks(os.path.join(path, subpath), base)
 
+
 def download_model(db_config: DreamboothConfig, token):
     tmp_dir = os.path.join(db_config.model_dir, "src")
@@ -994,7 +995,7 @@ def download_model(db_config: DreamboothConfig, token):
         if "model_index.json" in name:
             model_index = name
             continue
-        if (".ckpt" in name or ".safetensors" in name) and not "/" in name:
+        if (".ckpt" in name or ".safetensors" in name) and "/" not in name:
             model_files.append(name)
             continue
         for diffusion_dir in diffusion_dirs:
@@ -1034,8 +1035,6 @@ def download_model(db_config: DreamboothConfig, token):
         logger.debug("Nothing to fetch!")
         return None, None
 
-    # huggingface_hub.utils.tqdm.tqdm = mytqdm
-    mytqdm = huggingface_hub.utils.tqdm.tqdm
     out_model = None
     for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
@@ -1058,7 +1057,7 @@ def download_model(db_config: DreamboothConfig, token):
         for diffusion_dir in diffusion_dirs:
             if diffusion_dir in out:
                 out_model = db_config.pretrained_model_name_or_path
-                dest = os.path.join(db_config.pretrained_model_name_or_path,diffusion_dir)
+                dest = os.path.join(db_config.pretrained_model_name_or_path, diffusion_dir)
         if not dest:
             if ".ckpt" in out or ".safetensors" in out:
                 dest = os.path.join(db_config.model_dir, "src")
@@ -1074,12 +1073,13 @@ def download_model(db_config: DreamboothConfig, token):
     return out_model, config_file
 
 
 def get_config_path(
     model_version: str = "v1",
     train_type: str = "default",
     config_base_name: str = "training",
     prediction_type: str = "epsilon"
-    ):
+):
     train_type = f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v"
     parts = os.path.join(
@@ -1093,6 +1093,7 @@ def get_config_path(
     )
     return os.path.abspath(parts)
 
+
 def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
     config_base_name = "training"
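
get_config_path builds a training config filename from the model version and train type, appending "-v" when prediction_type is "v_prediction". A hedged usage sketch; the configs directory layout is not shown in this diff:

# Illustrative only; the resolved path depends on a configs layout that
# this diff does not show. With v_prediction, the train type gains a
# "-v" suffix, selecting the v-prediction variant of the YAML config.
config_path = get_config_path(
    model_version="v2",
    train_type="default",
    config_base_name="training",
    prediction_type="v_prediction",
)
# e.g. .../v2/training-default-v.yaml (hypothetical layout)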
@@ -1117,8 +1118,18 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
     return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
 
-def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
-                       new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
+def extract_checkpoint(
+    ctx: ConversionContext,
+    new_model_name: str,
+    checkpoint_file: str,
+    scheduler_type="ddim",
+    from_hub=False,
+    new_model_url="",
+    new_model_token="",
+    extract_ema=False,
+    train_unfrozen=False,
+    is_512=True,
+):
     """
     @param new_model_name: The name of the new model
@@ -1153,9 +1164,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
     msg = None
     if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
-        msg = "Please provide a URL and token for huggingface models."
-    if msg is not None:
-        return "", "", 0, 0, "", "", "", "", image_size, "", msg
+        logger.warning("Please provide a URL and token for huggingface models.")
+        return
 
     # Create empty config
     db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
@@ -1178,7 +1188,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         else:
             msg = "Unable to fetch model from hub."
             logger.warning(msg)
-            return "", "", 0, 0, "", "", "", "", image_size, "", msg
+            return
 
     try:
         checkpoint = None
@@ -1194,6 +1204,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
                 logger.debug("Loading safetensors...")
                 checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
             except Exception as e:
+                logger.warn("Failed to load as safetensors file, falling back to torch...", e)
                 checkpoint = torch.jit.load(checkpoint_file)
         else:
             logger.debug("Loading ckpt...")
@@ -1231,10 +1242,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
             logger.debug("UNet using v2 parameters.")
             v2 = True
-    except:
-        logger.error("Exception loading unet!")
-        traceback.print_exc()
+    except Exception as e:
+        logger.error("Exception loading unet!", traceback.format_exception(e))
 
     if v2 and not is_512:
         prediction_type = "v_prediction"
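
The v2 probe above keys off the width of a UNet cross-attention weight: SD v2's OpenCLIP text encoder emits 1024-dim context vectors where v1's CLIP emits 768, so a single projection shape identifies the checkpoint generation. A sketch with an assumed key name (key_name is assigned outside this hunk):

# Assumed key name; the real key_name is set outside this hunk.
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
v2 = key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024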
@@ -1263,7 +1272,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         if original_config_file is None or not os.path.exists(original_config_file):
             logger.warning("Unable to select a config file: %s" % (original_config_file))
-            return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
+            return
 
         logger.debug(f"Trying to load: {original_config_file}")
         original_config = OmegaConf.load(original_config_file)
@@ -1303,9 +1312,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         else:
             raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
 
-        logger.info("Converting UNet...")
         # Convert the UNet2DConditionModel model.
+        logger.info("Converting UNet...")
         unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
         unet_config["upcast_attention"] = upcast_attention
         unet = UNet2DConditionModel(**unet_config)
@@ -1317,16 +1325,16 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         db_config.save()
         unet.load_state_dict(converted_unet_checkpoint)
 
-        logger.info("Converting VAE...")
         # Convert the VAE model.
+        logger.info("Converting VAE...")
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
 
-        logger.info("Converting text encoder...")
         # Convert the text model.
+        logger.info("Converting text encoder...")
         text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
         if text_model_type == "FrozenOpenCLIPEmbedder":
             text_model = convert_open_clip_checkpoint(checkpoint)
@@ -1374,23 +1382,19 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
                                           scheduler=scheduler)
     except Exception as e:
-        logger.error(f"Exception setting up output: {e}")
+        logger.error("Exception setting up output: %s", traceback.format_exception(e))
         pipe = None
-        traceback.print_exc()
 
     if pipe is None or db_config is None:
         msg = "Pipeline or config is not set, unable to continue."
         logger.error(msg)
-        return "", "", 0, 0, "", "", "", "", image_size, "", msg
+        return
     else:
         resolution = db_config.resolution
         logger.info("Saving diffusion model...")
         pipe.save_pretrained(db_config.pretrained_model_name_or_path)
         result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
         model_dir = db_config.model_dir
         revision = db_config.revision
         scheduler = db_config.scheduler
         src = db_config.src
         required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
 
     if original_config_file is not None and os.path.exists(original_config_file):
         logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
@@ -1417,11 +1421,11 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         if not os.path.exists(rem_dir):
             os.makedirs(rem_dir)
 
     logger.info(result_status)
     return
 
 
 def convert_diffusion_original(
     ctx: ConversionContext,
     model: ModelDict,

View File

@@ -65,7 +65,9 @@ def convert_diffusion_stable(
     dest_path = path.join(ctx.model_path, name)
 
     # diffusers go into a directory rather than .onnx file
-    logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path)
+    logger.info(
+        "converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
+    )
 
     if single_vae:
         logger.info("converting model with single VAE")

View File

@@ -1,10 +1,8 @@
from logging import getLogger
from os import path
from shutil import copyfile
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from torch.onnx import export
from .utils import ConversionContext, ModelDict

View File

@@ -3,7 +3,7 @@ from functools import partial
 from logging import getLogger
 from os import path
 from pathlib import Path
-from typing import Dict, Union, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import requests
 import torch
@@ -44,7 +44,14 @@ def download_progress(urls: List[Tuple[str, str]]):
             logger.info("Destination already exists: %s", dest_path)
             return str(dest_path.absolute())
 
-        req = requests.get(url, stream=True, allow_redirects=True)
+        req = requests.get(
+            url,
+            stream=True,
+            allow_redirects=True,
+            headers={
+                "User-Agent": "onnx-web-api",
+            },
+        )
         if req.status_code != 200:
             req.raise_for_status()  # Only works for 4xx errors, per SO answer
             raise RuntimeError(
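
The new headers argument sends a User-Agent with every model download, which some hosts require before serving files. A standalone sketch of the same request pattern, with a placeholder URL:

# Standalone sketch of the request above; the URL is a placeholder.
# raise_for_status() raises HTTPError for any 4xx or 5xx response, so
# the RuntimeError after it only covers other non-200 statuses.
import requests

req = requests.get(
    "https://example.com/model.ckpt",  # placeholder
    stream=True,
    allow_redirects=True,
    headers={"User-Agent": "onnx-web-api"},
)
req.raise_for_status()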
@@ -117,6 +124,7 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
 known_formats = ["onnx", "pth", "ckpt", "safetensors"]
 
+
 def source_format(model: Dict) -> Optional[str]:
     if "format" in model:
         return model["format"]