
lint(api): apply to original diffusers converter

Sean Sube 2023-02-10 23:32:16 -06:00
parent dbee258a36
commit c599385a30
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 67 additions and 45 deletions

View File

@@ -1,20 +1,26 @@
 import warnings
 from argparse import ArgumentParser
-from json import loads
 from logging import getLogger
 from os import environ, makedirs, path
 from sys import exit
 from typing import Dict, List, Optional, Tuple
-from yaml import safe_load
-from jsonschema import validate, ValidationError

 import torch
+from jsonschema import ValidationError, validate
+from yaml import safe_load

 from .correction_gfpgan import convert_correction_gfpgan
 from .diffusion_original import convert_diffusion_original
 from .diffusion_stable import convert_diffusion_stable
 from .upscale_resrgan import convert_upscale_resrgan
-from .utils import ConversionContext, download_progress, source_format, tuple_to_correction, tuple_to_diffusion, tuple_to_upscaling
+from .utils import (
+    ConversionContext,
+    download_progress,
+    source_format,
+    tuple_to_correction,
+    tuple_to_diffusion,
+    tuple_to_upscaling,
+)

 # suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
 warnings.filterwarnings(
@@ -100,7 +106,9 @@ model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
 training_device = "cuda" if torch.cuda.is_available() else "cpu"

-def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional[str] = None) -> str:
+def fetch_model(
+    ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
+) -> str:
     cache_name = path.join(ctx.cache_path, name)

     if format is not None:
         # add an extension if possible, some of the conversion code checks for it
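For context, the cache-name handling this hunk reformats amounts to the sketch below; the body of the format branch continues past the hunk, so the line after the comment is an assumption rather than the repo's exact code:

    from os import path
    from typing import Optional

    def cache_name_for(cache_path: str, name: str, format: Optional[str] = None) -> str:
        cache_name = path.join(cache_path, name)
        if format is not None:
            # add an extension if possible, some of the conversion code checks for it
            cache_name = f"{cache_name}.{format}"  # assumed continuation
        return cache_name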
@@ -110,7 +118,9 @@ def fetch_model(ctx: ConversionContext, name: str, source: str, format: Optional
         api_name, api_root = model_sources.get(proto)
         if source.startswith(proto):
             api_source = api_root % (source.removeprefix(proto))
-            logger.info("Downloading model from %s: %s -> %s", api_name, api_source, cache_name)
+            logger.info(
+                "Downloading model from %s: %s -> %s", api_name, api_source, cache_name
+            )
             return download_progress([(api_source, cache_name)])

     if source.startswith(model_source_huggingface):
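The surrounding function resolves shorthand model sources against an API table. A minimal sketch of that pattern, with a hypothetical model_sources entry (the real table and URL templates are outside this hunk):

    from typing import Dict, Tuple

    # hypothetical entry; the repo's actual table is not shown in this diff
    model_sources: Dict[str, Tuple[str, str]] = {
        "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
    }

    def resolve_source(source: str) -> str:
        for proto, (api_name, api_root) in model_sources.items():
            if source.startswith(proto):
                # substitute the bare model ID into the URL template
                return api_root % (source.removeprefix(proto))
        return source

    # resolve_source("civitai://5280") -> "https://civitai.com/api/download/models/5280"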
@@ -218,7 +228,9 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)

-    ctx = ConversionContext(model_path, training_device, half=args.half, opset=args.opset, token=args.token)
+    ctx = ConversionContext(
+        model_path, training_device, half=args.half, opset=args.opset, token=args.token
+    )
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)

     if not path.exists(model_path):

View File

@@ -1,10 +1,8 @@
 from logging import getLogger
 from os import path
-from shutil import copyfile

 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils.download_util import load_file_from_url
 from torch.onnx import export

 from .utils import ConversionContext, ModelDict

View File

@@ -959,6 +959,7 @@ def replace_symlinks(path, base):
     for subpath in os.listdir(path):
         replace_symlinks(os.path.join(path, subpath), base)

+
 def download_model(db_config: DreamboothConfig, token):
     tmp_dir = os.path.join(db_config.model_dir, "src")
@@ -994,7 +995,7 @@ def download_model(db_config: DreamboothConfig, token):
             if "model_index.json" in name:
                 model_index = name
                 continue
-            if (".ckpt" in name or ".safetensors" in name) and not "/" in name:
+            if (".ckpt" in name or ".safetensors" in name) and "/" not in name:
                 model_files.append(name)
                 continue
             for diffusion_dir in diffusion_dirs:
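The only change here is flake8's E713: "not x in y" and "x not in y" evaluate identically, so the swap is purely stylistic. A quick check:

    # both spellings test membership the same way; E713 prefers the second
    name = "v1-5-pruned.safetensors"
    assert (not "/" in name) == ("/" not in name)  # noqa: E713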
@@ -1034,8 +1035,6 @@ def download_model(db_config: DreamboothConfig, token):
         logger.debug("Nothing to fetch!")
         return None, None

-    # huggingface_hub.utils.tqdm.tqdm = mytqdm
-
     mytqdm = huggingface_hub.utils.tqdm.tqdm
     out_model = None
     for repo_file in mytqdm(files_to_fetch, desc=f"Fetching {len(files_to_fetch)} files"):
@@ -1058,7 +1057,7 @@ def download_model(db_config: DreamboothConfig, token):
         for diffusion_dir in diffusion_dirs:
             if diffusion_dir in out:
                 out_model = db_config.pretrained_model_name_or_path
-                dest = os.path.join(db_config.pretrained_model_name_or_path,diffusion_dir)
+                dest = os.path.join(db_config.pretrained_model_name_or_path, diffusion_dir)
         if not dest:
             if ".ckpt" in out or ".safetensors" in out:
                 dest = os.path.join(db_config.model_dir, "src")
@@ -1074,12 +1073,13 @@ def download_model(db_config: DreamboothConfig, token):

     return out_model, config_file

+
 def get_config_path(
     model_version: str = "v1",
     train_type: str = "default",
     config_base_name: str = "training",
     prediction_type: str = "epsilon"
 ):
     train_type = f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v"
     parts = os.path.join(
@@ -1093,6 +1093,7 @@ def get_config_path(
     )
     return os.path.abspath(parts)

+
 def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
     config_base_name = "training"
@@ -1117,8 +1118,18 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):

     return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)

-def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_file: str, scheduler_type="ddim", from_hub=False, new_model_url="",
-                       new_model_token="", extract_ema=False, train_unfrozen=False, is_512=True):
+def extract_checkpoint(
+    ctx: ConversionContext,
+    new_model_name: str,
+    checkpoint_file: str,
+    scheduler_type="ddim",
+    from_hub=False,
+    new_model_url="",
+    new_model_token="",
+    extract_ema=False,
+    train_unfrozen=False,
+    is_512=True,
+):
     """
     @param new_model_name: The name of the new model
@@ -1153,9 +1164,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
     msg = None
     if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
-        msg = "Please provide a URL and token for huggingface models."
-        if msg is not None:
-            return "", "", 0, 0, "", "", "", "", image_size, "", msg
+        logger.warning("Please provide a URL and token for huggingface models.")
+        return

     # Create empty config
     db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
@@ -1178,7 +1188,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         else:
             msg = "Unable to fetch model from hub."
             logger.warning(msg)
-            return "", "", 0, 0, "", "", "", "", image_size, "", msg
+            return

     try:
         checkpoint = None
@@ -1194,6 +1204,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
                 logger.debug("Loading safetensors...")
                 checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
             except Exception as e:
+                logger.warn("Failed to load as safetensors file, falling back to torch...", e)
                 checkpoint = torch.jit.load(checkpoint_file)
         else:
             logger.debug("Loading ckpt...")
@@ -1231,10 +1242,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
                 logger.debug("UNet using v2 parameters.")
                 v2 = True
-
-    except:
-        logger.error("Exception loading unet!")
-        traceback.print_exc()
+    except Exception as e:
+        logger.error("Exception loading unet!", traceback.format_exception(e))

     if v2 and not is_512:
         prediction_type = "v_prediction"
@@ -1263,7 +1272,7 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f

     if original_config_file is None or not os.path.exists(original_config_file):
         logger.warning("Unable to select a config file: %s" % (original_config_file))
-        return "", "", 0, 0, "", "", "", "", image_size, "", "Unable to find a config file."
+        return

     logger.debug(f"Trying to load: {original_config_file}")
     original_config = OmegaConf.load(original_config_file)
@@ -1303,9 +1312,8 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
     else:
         raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

-    logger.info("Converting UNet...")
-
     # Convert the UNet2DConditionModel model.
+    logger.info("Converting UNet...")
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
     unet = UNet2DConditionModel(**unet_config)
@@ -1317,16 +1325,16 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
         db_config.save()
     unet.load_state_dict(converted_unet_checkpoint)

-    logger.info("Converting VAE...")
     # Convert the VAE model.
+    logger.info("Converting VAE...")
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)

-    logger.info("Converting text encoder...")
     # Convert the text model.
+    logger.info("Converting text encoder...")
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
     if text_model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
@@ -1374,23 +1382,19 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
                                           scheduler=scheduler)
     except Exception as e:
-        logger.error(f"Exception setting up output: {e}")
+        logger.error("Exception setting up output: %s", traceback.format_exception(e))
         pipe = None
-        traceback.print_exc()

     if pipe is None or db_config is None:
         msg = "Pipeline or config is not set, unable to continue."
         logger.error(msg)
-        return "", "", 0, 0, "", "", "", "", image_size, "", msg
+        return
     else:
-        resolution = db_config.resolution
         logger.info("Saving diffusion model...")
         pipe.save_pretrained(db_config.pretrained_model_name_or_path)
         result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
-        model_dir = db_config.model_dir
         revision = db_config.revision
         scheduler = db_config.scheduler
-        src = db_config.src
         required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
         if original_config_file is not None and os.path.exists(original_config_file):
             logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
@@ -1417,11 +1421,11 @@ def extract_checkpoint(ctx: ConversionContext, new_model_name: str, checkpoint_f
             if not os.path.exists(rem_dir):
                 os.makedirs(rem_dir)

     logger.info(result_status)
-
     return

+
 def convert_diffusion_original(
     ctx: ConversionContext,
     model: ModelDict,

View File

@@ -65,7 +65,9 @@ def convert_diffusion_stable(
     dest_path = path.join(ctx.model_path, name)

     # diffusers go into a directory rather than .onnx file
-    logger.info("converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path)
+    logger.info(
+        "converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
+    )

     if single_vae:
         logger.info("converting model with single VAE")

View File

@@ -1,10 +1,8 @@
 from logging import getLogger
 from os import path
-from shutil import copyfile

 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils.download_util import load_file_from_url
 from torch.onnx import export

 from .utils import ConversionContext, ModelDict

View File

@@ -3,7 +3,7 @@ from functools import partial
 from logging import getLogger
 from os import path
 from pathlib import Path
-from typing import Dict, Union, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import requests
 import torch
@@ -44,7 +44,14 @@ def download_progress(urls: List[Tuple[str, str]]):
             logger.info("Destination already exists: %s", dest_path)
             return str(dest_path.absolute())

-        req = requests.get(url, stream=True, allow_redirects=True)
+        req = requests.get(
+            url,
+            stream=True,
+            allow_redirects=True,
+            headers={
+                "User-Agent": "onnx-web-api",
+            },
+        )
         if req.status_code != 200:
             req.raise_for_status()  # Only works for 4xx errors, per SO answer
             raise RuntimeError(
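The added headers argument matters because some model hosts reject the default python-requests User-Agent. A minimal streamed download along the same lines; the helper name and chunk size are illustrative, not from the repo:

    import requests

    def download(url: str, dest_path: str) -> str:
        req = requests.get(
            url,
            stream=True,
            allow_redirects=True,
            # some hosts 403 the default "python-requests/x.y" agent
            headers={"User-Agent": "onnx-web-api"},
        )
        # raise_for_status raises for both 4xx and 5xx responses
        req.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in req.iter_content(chunk_size=8192):
                f.write(chunk)
        return dest_path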
@@ -117,6 +124,7 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
 known_formats = ["onnx", "pth", "ckpt", "safetensors"]

+
 def source_format(model: Dict) -> Optional[str]:
     if "format" in model:
         return model["format"]