lint(api): move half-precision CUDA check before models, apply lint
This commit is contained in:
parent
694d15547f
commit
454abcdddc
|
@ -227,6 +227,11 @@ def main() -> int:
|
||||||
ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
|
ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
|
||||||
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
||||||
|
|
||||||
|
if ctx.half and ctx.training_device != "cuda":
|
||||||
|
raise ValueError(
|
||||||
|
"Half precision model export is only supported on GPUs with CUDA"
|
||||||
|
)
|
||||||
|
|
||||||
if not path.exists(ctx.model_path):
|
if not path.exists(ctx.model_path):
|
||||||
logger.info("Model path does not existing, creating: %s", ctx.model_path)
|
logger.info("Model path does not existing, creating: %s", ctx.model_path)
|
||||||
makedirs(ctx.model_path)
|
makedirs(ctx.model_path)
|
||||||
|
|
|
@ -41,7 +41,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
||||||
LDMBertConfig,
|
LDMBertConfig,
|
||||||
LDMBertModel,
|
LDMBertModel,
|
||||||
)
|
)
|
||||||
import sys
|
|
||||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
|
@ -54,7 +53,7 @@ from transformers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .diffusion_stable import convert_diffusion_stable
|
from .diffusion_stable import convert_diffusion_stable
|
||||||
from .utils import ConversionContext, ModelDict, sanitize_name, load_yaml
|
from .utils import ConversionContext, ModelDict, load_yaml, sanitize_name
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -1018,7 +1017,16 @@ def download_model(db_config: TrainingConfig, token):
|
||||||
if safe_model and bin_model:
|
if safe_model and bin_model:
|
||||||
diffusion_files.remove(bin_model)
|
diffusion_files.remove(bin_model)
|
||||||
|
|
||||||
model_file = next((x for x in model_files if ".safetensors" in x and "nonema" in x), next((x for x in model_files if "nonema" in x), next((x for x in model_files if ".safetensors" in x), model_files[0] if model_files else None)))
|
model_file = next(
|
||||||
|
(x for x in model_files if ".safetensors" in x and "nonema" in x),
|
||||||
|
next(
|
||||||
|
(x for x in model_files if "nonema" in x),
|
||||||
|
next(
|
||||||
|
(x for x in model_files if ".safetensors" in x),
|
||||||
|
model_files[0] if model_files else None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
files_to_fetch = None
|
files_to_fetch = None
|
||||||
|
|
||||||
|
@ -1240,7 +1248,7 @@ def extract_checkpoint(
|
||||||
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
|
if key_name in unet_dict and unet_dict[key_name].shape[-1] == 1024:
|
||||||
logger.debug("UNet using v2 parameters.")
|
logger.debug("UNet using v2 parameters.")
|
||||||
v2 = True
|
v2 = True
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
|
logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
|
||||||
|
|
||||||
if v2 and not is_512:
|
if v2 and not is_512:
|
||||||
|
@ -1379,7 +1387,7 @@ def extract_checkpoint(
|
||||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
|
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
|
||||||
scheduler=scheduler)
|
scheduler=scheduler)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
|
logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
|
||||||
pipe = None
|
pipe = None
|
||||||
|
|
||||||
|
|
|
@ -76,11 +76,6 @@ def convert_diffusion_stable(
|
||||||
logger.info("ONNX model already exists, skipping.")
|
logger.info("ONNX model already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if ctx.half and ctx.training_device != "cuda":
|
|
||||||
raise ValueError(
|
|
||||||
"Half precision model export is only supported on GPUs with CUDA"
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
source,
|
source,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
|
|
|
@ -145,7 +145,6 @@ def source_format(model: Dict) -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Config(object):
|
class Config(object):
|
||||||
def __init__(self, kwargs):
|
def __init__(self, kwargs):
|
||||||
self.__dict__.update(kwargs)
|
self.__dict__.update(kwargs)
|
||||||
|
@ -165,7 +164,6 @@ class Config(object):
|
||||||
setattr(target, k, v)
|
setattr(target, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_yaml(file: str) -> str:
|
def load_yaml(file: str) -> str:
|
||||||
with open(file, "r") as f:
|
with open(file, "r") as f:
|
||||||
data = safe_load(f.read())
|
data = safe_load(f.read())
|
||||||
|
@ -173,5 +171,7 @@ def load_yaml(file: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
safe_chars = "._-"
|
safe_chars = "._-"
|
||||||
|
|
||||||
|
|
||||||
def sanitize_name(name):
|
def sanitize_name(name):
|
||||||
return "".join(x for x in name if (x.isalnum() or x in safe_chars))
|
return "".join(x for x in name if (x.isalnum() or x in safe_chars))
|
||||||
|
|
Loading…
Reference in New Issue