1
0
Fork 0

lint(api): move half-precision CUDA check before models, apply lint

This commit is contained in:
Sean Sube 2023-02-11 13:31:34 -06:00
parent 694d15547f
commit 454abcdddc
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 21 additions and 13 deletions

View File

@ -227,6 +227,11 @@ def main() -> int:
ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token) ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device) logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
if ctx.half and ctx.training_device != "cuda":
raise ValueError(
"Half precision model export is only supported on GPUs with CUDA"
)
if not path.exists(ctx.model_path): if not path.exists(ctx.model_path):
logger.info("Model path does not existing, creating: %s", ctx.model_path) logger.info("Model path does not existing, creating: %s", ctx.model_path)
makedirs(ctx.model_path) makedirs(ctx.model_path)

View File

@ -41,7 +41,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
LDMBertConfig, LDMBertConfig,
LDMBertModel, LDMBertModel,
) )
import sys
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfApi, hf_hub_download from huggingface_hub import HfApi, hf_hub_download
@ -54,7 +53,7 @@ from transformers import (
) )
from .diffusion_stable import convert_diffusion_stable from .diffusion_stable import convert_diffusion_stable
from .utils import ConversionContext, ModelDict, sanitize_name, load_yaml from .utils import ConversionContext, ModelDict, load_yaml, sanitize_name
logger = getLogger(__name__) logger = getLogger(__name__)
@ -1018,7 +1017,16 @@ def download_model(db_config: TrainingConfig, token):
if safe_model and bin_model: if safe_model and bin_model:
diffusion_files.remove(bin_model) diffusion_files.remove(bin_model)
model_file = next((x for x in model_files if ".safetensors" in x and "nonema" in x), next((x for x in model_files if "nonema" in x), next((x for x in model_files if ".safetensors" in x), model_files[0] if model_files else None))) model_file = next(
(x for x in model_files if ".safetensors" in x and "nonema" in x),
next(
(x for x in model_files if "nonema" in x),
next(
(x for x in model_files if ".safetensors" in x),
model_files[0] if model_files else None
)
)
)
files_to_fetch = None files_to_fetch = None
@ -1167,7 +1175,7 @@ def extract_checkpoint(
# Create empty config # Create empty config
db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type, db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
src=checkpoint_file if not from_hub else new_model_url) src=checkpoint_file if not from_hub else new_model_url)
original_config_file = None original_config_file = None
@ -1240,7 +1248,7 @@ def extract_checkpoint(
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024: if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
logger.debug("UNet using v2 parameters.") logger.debug("UNet using v2 parameters.")
v2 = True v2 = True
except Exception as e: except Exception:
logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info())) logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
if v2 and not is_512: if v2 and not is_512:
@ -1379,7 +1387,7 @@ def extract_checkpoint(
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
scheduler=scheduler) scheduler=scheduler)
except Exception as e: except Exception:
logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info())) logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
pipe = None pipe = None

View File

@ -76,11 +76,6 @@ def convert_diffusion_stable(
logger.info("ONNX model already exists, skipping.") logger.info("ONNX model already exists, skipping.")
return return
if ctx.half and ctx.training_device != "cuda":
raise ValueError(
"Half precision model export is only supported on GPUs with CUDA"
)
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = StableDiffusionPipeline.from_pretrained(
source, source,
torch_dtype=dtype, torch_dtype=dtype,

View File

@ -145,7 +145,6 @@ def source_format(model: Dict) -> Optional[str]:
return None return None
class Config(object): class Config(object):
def __init__(self, kwargs): def __init__(self, kwargs):
self.__dict__.update(kwargs) self.__dict__.update(kwargs)
@ -165,7 +164,6 @@ class Config(object):
setattr(target, k, v) setattr(target, k, v)
def load_yaml(file: str) -> str: def load_yaml(file: str) -> str:
with open(file, "r") as f: with open(file, "r") as f:
data = safe_load(f.read()) data = safe_load(f.read())
@ -173,5 +171,7 @@ def load_yaml(file: str) -> str:
safe_chars = "._-" safe_chars = "._-"
def sanitize_name(name): def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in safe_chars)) return "".join(x for x in name if (x.isalnum() or x in safe_chars))