1
0
Fork 0

lint(api): move half-precision CUDA check before models, apply lint

This commit is contained in:
Sean Sube 2023-02-11 13:31:34 -06:00
parent 694d15547f
commit 454abcdddc
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 21 additions and 13 deletions

View File

@ -227,6 +227,11 @@ def main() -> int:
ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
if ctx.half and ctx.training_device != "cuda":
raise ValueError(
"Half precision model export is only supported on GPUs with CUDA"
)
if not path.exists(ctx.model_path):
logger.info("Model path does not existing, creating: %s", ctx.model_path)
makedirs(ctx.model_path)

View File

@ -41,7 +41,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
LDMBertConfig,
LDMBertModel,
)
import sys
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfApi, hf_hub_download
@ -54,7 +53,7 @@ from transformers import (
)
from .diffusion_stable import convert_diffusion_stable
from .utils import ConversionContext, ModelDict, sanitize_name, load_yaml
from .utils import ConversionContext, ModelDict, load_yaml, sanitize_name
logger = getLogger(__name__)
@ -1018,7 +1017,16 @@ def download_model(db_config: TrainingConfig, token):
if safe_model and bin_model:
diffusion_files.remove(bin_model)
model_file = next((x for x in model_files if ".safetensors" in x and "nonema" in x), next((x for x in model_files if "nonema" in x), next((x for x in model_files if ".safetensors" in x), model_files[0] if model_files else None)))
model_file = next(
(x for x in model_files if ".safetensors" in x and "nonema" in x),
next(
(x for x in model_files if "nonema" in x),
next(
(x for x in model_files if ".safetensors" in x),
model_files[0] if model_files else None
)
)
)
files_to_fetch = None
@ -1167,7 +1175,7 @@ def extract_checkpoint(
# Create empty config
db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
src=checkpoint_file if not from_hub else new_model_url)
src=checkpoint_file if not from_hub else new_model_url)
original_config_file = None
@ -1240,7 +1248,7 @@ def extract_checkpoint(
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
logger.debug("UNet using v2 parameters.")
v2 = True
except Exception as e:
except Exception:
logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
if v2 and not is_512:
@ -1379,7 +1387,7 @@ def extract_checkpoint(
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
scheduler=scheduler)
except Exception as e:
except Exception:
logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
pipe = None

View File

@ -76,11 +76,6 @@ def convert_diffusion_stable(
logger.info("ONNX model already exists, skipping.")
return
if ctx.half and ctx.training_device != "cuda":
raise ValueError(
"Half precision model export is only supported on GPUs with CUDA"
)
pipeline = StableDiffusionPipeline.from_pretrained(
source,
torch_dtype=dtype,

View File

@ -145,7 +145,6 @@ def source_format(model: Dict) -> Optional[str]:
return None
class Config(object):
def __init__(self, kwargs):
self.__dict__.update(kwargs)
@ -165,7 +164,6 @@ class Config(object):
setattr(target, k, v)
def load_yaml(file: str) -> str:
with open(file, "r") as f:
data = safe_load(f.read())
@ -173,5 +171,7 @@ def load_yaml(file: str) -> str:
safe_chars = "._-"
def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in safe_chars))