diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index 2c4eeb66..a11d706f 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -227,6 +227,11 @@ def main() -> int:
     ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
 
+    if ctx.half and ctx.training_device != "cuda":
+        raise ValueError(
+            "Half precision model export is only supported on GPUs with CUDA"
+        )
+
     if not path.exists(ctx.model_path):
         logger.info("Model path does not existing, creating: %s", ctx.model_path)
         makedirs(ctx.model_path)
diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py
index 292be414..63a6575a 100644
--- a/api/onnx_web/convert/diffusion_original.py
+++ b/api/onnx_web/convert/diffusion_original.py
@@ -41,7 +41,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
     LDMBertConfig,
     LDMBertModel,
 )
-import sys
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfApi, hf_hub_download
@@ -54,7 +53,7 @@ from transformers import (
 )
 
 from .diffusion_stable import convert_diffusion_stable
-from .utils import ConversionContext, ModelDict, sanitize_name, load_yaml
+from .utils import ConversionContext, ModelDict, load_yaml, sanitize_name
 
 logger = getLogger(__name__)
 
@@ -1018,7 +1017,16 @@ def download_model(db_config: TrainingConfig, token):
     if safe_model and bin_model:
         diffusion_files.remove(bin_model)
 
-    model_file = next((x for x in model_files if ".safetensors" in x and "nonema" in x), next((x for x in model_files if "nonema" in x), next((x for x in model_files if ".safetensors" in x), model_files[0] if model_files else None)))
+    model_file = next(
+        (x for x in model_files if ".safetensors" in x and "nonema" in x),
+        next(
+            (x for x in model_files if "nonema" in x),
+            next(
+                (x for x in model_files if ".safetensors" in x),
+                model_files[0] if model_files else None
+            )
+        )
+    )
 
     files_to_fetch = None
 
@@ -1167,7 +1175,7 @@ def extract_checkpoint(
 
     # Create empty config
     db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
-        src=checkpoint_file if not from_hub else new_model_url)
+                               src=checkpoint_file if not from_hub else new_model_url)
 
     original_config_file = None
 
@@ -1240,7 +1248,7 @@ def extract_checkpoint(
         if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
             logger.debug("UNet using v2 parameters.")
             v2 = True
-    except Exception as e:
+    except Exception:
         logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
 
     if v2 and not is_512:
@@ -1379,7 +1387,7 @@ def extract_checkpoint(
 
             tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
             pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-    except Exception as e:
+    except Exception:
         logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
         pipe = None
 
diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion_stable.py
index 9afcddf9..e1708f98 100644
--- a/api/onnx_web/convert/diffusion_stable.py
+++ b/api/onnx_web/convert/diffusion_stable.py
@@ -76,11 +76,6 @@ def convert_diffusion_stable(
         logger.info("ONNX model already exists, skipping.")
         return
 
-    if ctx.half and ctx.training_device != "cuda":
-        raise ValueError(
-            "Half precision model export is only supported on GPUs with CUDA"
-        )
-
     pipeline = StableDiffusionPipeline.from_pretrained(
         source,
         torch_dtype=dtype,
diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py
index cedcd96b..d36e3e2c 100644
--- a/api/onnx_web/convert/utils.py
+++ b/api/onnx_web/convert/utils.py
@@ -145,7 +145,6 @@ def source_format(model: Dict) -> Optional[str]:
     return None
 
 
-
 class Config(object):
     def __init__(self, kwargs):
         self.__dict__.update(kwargs)
@@ -165,7 +164,6 @@ class Config(object):
             setattr(target, k, v)
 
 
-
 def load_yaml(file: str) -> str:
     with open(file, "r") as f:
         data = safe_load(f.read())
@@ -173,5 +171,7 @@ def load_yaml(file: str) -> str:
 
 
 safe_chars = "._-"
+
+
 def sanitize_name(name):
     return "".join(x for x in name if (x.isalnum() or x in safe_chars))
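
Note on the reformatted fallback in download_model: the nested next() calls prefer a ".safetensors" file whose name contains "nonema", then any "nonema" file, then any ".safetensors" file, and finally the first entry (or None for an empty list). Below is a minimal standalone sketch of that selection order; the pick_model_file helper and the sample filenames are illustrative only, not part of this patch.

from typing import List, Optional


def pick_model_file(model_files: List[str]) -> Optional[str]:
    # Same priority chain as the nested next() calls in download_model:
    # 1. ".safetensors" file whose name contains "nonema"
    # 2. any file whose name contains "nonema"
    # 3. any ".safetensors" file
    # 4. the first file in the list, or None if the list is empty
    return next(
        (x for x in model_files if ".safetensors" in x and "nonema" in x),
        next(
            (x for x in model_files if "nonema" in x),
            next(
                (x for x in model_files if ".safetensors" in x),
                model_files[0] if model_files else None,
            ),
        ),
    )


# Illustrative filenames only.
assert pick_model_file(["v2-1_768-ema-pruned.ckpt", "v2-1_768-nonema-pruned.safetensors"]) == "v2-1_768-nonema-pruned.safetensors"
assert pick_model_file(["model.ckpt", "model.safetensors"]) == "model.safetensors"
assert pick_model_file(["model.ckpt"]) == "model.ckpt"
assert pick_model_file([]) is None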