1
0
Fork 0

lint(api): lowercase log messages

This commit is contained in:
Sean Sube 2023-02-16 18:42:05 -06:00
parent 0ed4af18ad
commit 5e9dfa3465
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 57 additions and 57 deletions

View File

@ -142,7 +142,7 @@ def fetch_model(
for p in [model_path, model_onnx]: for p in [model_path, model_onnx]:
if path.exists(p): if path.exists(p):
logger.debug("Model already exists, skipping fetch.") logger.debug("model already exists, skipping fetch")
return p return p
# add an extension if possible, some of the conversion code checks for it # add an extension if possible, some of the conversion code checks for it
@ -160,26 +160,26 @@ def fetch_model(
if source.startswith(proto): if source.startswith(proto):
api_source = api_root % (remove_prefix(source, proto)) api_source = api_root % (remove_prefix(source, proto))
logger.info( logger.info(
"Downloading model from %s: %s -> %s", api_name, api_source, cache_name "downloading model from %s: %s -> %s", api_name, api_source, cache_name
) )
return download_progress([(api_source, cache_name)]) return download_progress([(api_source, cache_name)])
if source.startswith(model_source_huggingface): if source.startswith(model_source_huggingface):
hub_source = remove_prefix(source, model_source_huggingface) hub_source = remove_prefix(source, model_source_huggingface)
logger.info("Downloading model from Huggingface Hub: %s", hub_source) logger.info("downloading model from Huggingface Hub: %s", hub_source)
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not # from_pretrained has a bunch of useful logic that snapshot_download by itself down not
return hub_source return hub_source
elif source.startswith("https://"): elif source.startswith("https://"):
logger.info("Downloading model from: %s", source) logger.info("downloading model from: %s", source)
return download_progress([(source, cache_name)]) return download_progress([(source, cache_name)])
elif source.startswith("http://"): elif source.startswith("http://"):
logger.warning("Downloading model from insecure source: %s", source) logger.warning("downloading model from insecure source: %s", source)
return download_progress([(source, cache_name)]) return download_progress([(source, cache_name)])
elif source.startswith(path.sep) or source.startswith("."): elif source.startswith(path.sep) or source.startswith("."):
logger.info("Using local model: %s", source) logger.info("using local model: %s", source)
return source return source
else: else:
logger.info("Unknown model location, using path as provided: %s", source) logger.info("unknown model location, using path as provided: %s", source)
return source return source
@ -190,12 +190,12 @@ def convert_models(ctx: ConversionContext, args, models: Models):
name = model.get("name") name = model.get("name")
if name in args.skip: if name in args.skip:
logger.info("Skipping source: %s", name) logger.info("skipping source: %s", name)
else: else:
model_format = source_format(model) model_format = source_format(model)
source = model["source"] source = model["source"]
dest = fetch_model(ctx, name, source, model_format=model_format) dest = fetch_model(ctx, name, source, model_format=model_format)
logger.info("Finished downloading source: %s -> %s", source, dest) logger.info("finished downloading source: %s -> %s", source, dest)
if args.diffusion and "diffusion" in models: if args.diffusion and "diffusion" in models:
for model in models.get("diffusion"): for model in models.get("diffusion"):
@ -203,7 +203,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
name = model.get("name") name = model.get("name")
if name in args.skip: if name in args.skip:
logger.info("Skipping model: %s", name) logger.info("skipping model: %s", name)
else: else:
model_format = source_format(model) model_format = source_format(model)
source = fetch_model( source = fetch_model(
@ -229,7 +229,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
name = model.get("name") name = model.get("name")
if name in args.skip: if name in args.skip:
logger.info("Skipping model: %s", name) logger.info("skipping model: %s", name)
else: else:
model_format = source_format(model) model_format = source_format(model)
source = fetch_model( source = fetch_model(
@ -243,7 +243,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
name = model.get("name") name = model.get("name")
if name in args.skip: if name in args.skip:
logger.info("Skipping model: %s", name) logger.info("skipping model: %s", name)
else: else:
model_format = source_format(model) model_format = source_format(model)
source = fetch_model( source = fetch_model(
@ -290,23 +290,23 @@ def main() -> int:
logger.info("CLI arguments: %s", args) logger.info("CLI arguments: %s", args)
ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token) ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device) logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device)
if ctx.half and ctx.training_device != "cuda": if ctx.half and ctx.training_device != "cuda":
raise ValueError( raise ValueError(
"Half precision model export is only supported on GPUs with CUDA" "half precision model export is only supported on GPUs with CUDA"
) )
if not path.exists(ctx.model_path): if not path.exists(ctx.model_path):
logger.info("Model path does not existing, creating: %s", ctx.model_path) logger.info("model path does not existing, creating: %s", ctx.model_path)
makedirs(ctx.model_path) makedirs(ctx.model_path)
logger.info("Converting base models.") logger.info("converting base models")
convert_models(ctx, args, base_models) convert_models(ctx, args, base_models)
for file in args.extras: for file in args.extras:
if file is not None and file != "": if file is not None and file != "":
logger.info("Loading extra models from %s", file) logger.info("loading extra models from %s", file)
try: try:
with open(file, "r") as f: with open(file, "r") as f:
data = safe_load(f.read()) data = safe_load(f.read())
@ -318,12 +318,12 @@ def main() -> int:
try: try:
validate(data, schema) validate(data, schema)
logger.info("Converting extra models.") logger.info("converting extra models")
convert_models(ctx, args, data) convert_models(ctx, args, data)
except ValidationError as err: except ValidationError as err:
logger.error("Invalid data in extras file: %s", err) logger.error("invalid data in extras file: %s", err)
except Exception as err: except Exception as err:
logger.error("Error converting extra models: %s", err) logger.error("error converting extra models: %s", err)
return 0 return 0

View File

@ -24,7 +24,7 @@ def convert_correction_gfpgan(
logger.info("converting GFPGAN model: %s -> %s", name, dest) logger.info("converting GFPGAN model: %s -> %s", name, dest)
if path.isfile(dest): if path.isfile(dest):
logger.info("ONNX model already exists, skipping.") logger.info("ONNX model already exists, skipping")
return return
logger.info("loading and training model") logger.info("loading and training model")
@ -66,4 +66,4 @@ def convert_correction_gfpgan(
opset_version=ctx.opset, opset_version=ctx.opset,
export_params=True, export_params=True,
) )
logger.info("GFPGAN exported to ONNX successfully.") logger.info("GFPGAN exported to ONNX successfully")

View File

@ -907,7 +907,7 @@ def convert_open_clip_checkpoint(checkpoint):
if 'cond_stage_model.model.text_projection' in checkpoint: if 'cond_stage_model.model.text_projection' in checkpoint:
d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0]) d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
else: else:
logger.debug("No projection shape found, setting to 1024") logger.debug("no projection shape found, setting to 1024")
d_model = 1024 d_model = 1024
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
@ -962,7 +962,7 @@ def replace_symlinks(path, base):
blob_path = None blob_path = None
if blob_path is None: if blob_path is None:
logger.debug("NO BLOB") logger.debug("no blob")
return return
os.replace(blob_path, path) os.replace(blob_path, path)
elif os.path.isdir(path): elif os.path.isdir(path):
@ -985,7 +985,7 @@ def download_model(db_config: TrainingConfig, token):
) )
if repo_info.sha is None: if repo_info.sha is None:
logger.warning("Unable to fetch repo info: %s", hub_url) logger.warning("unable to fetch repo info: %s", hub_url)
return None, None return None, None
siblings = repo_info.siblings siblings = repo_info.siblings
@ -1049,7 +1049,7 @@ def download_model(db_config: TrainingConfig, token):
logger.info(f"Fetching files: {files_to_fetch}") logger.info(f"Fetching files: {files_to_fetch}")
if not len(files_to_fetch): if not len(files_to_fetch):
logger.debug("Nothing to fetch!") logger.debug("nothing to fetch")
return None, None return None, None
mytqdm = huggingface_hub.utils.tqdm.tqdm mytqdm = huggingface_hub.utils.tqdm.tqdm
@ -1190,18 +1190,18 @@ def extract_checkpoint(
map_location = torch.device("cpu") map_location = torch.device("cpu")
# Try to determine if v1 or v2 model if we have a ckpt # Try to determine if v1 or v2 model if we have a ckpt
logger.info("Loading model from checkpoint.") logger.info("loading model from checkpoint")
_, extension = os.path.splitext(checkpoint_file) _, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
os.environ["SAFETENSORS_FAST_GPU"] = "1" os.environ["SAFETENSORS_FAST_GPU"] = "1"
try: try:
logger.debug("Loading safetensors...") logger.debug("loading safetensors")
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu") checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
except Exception as e: except Exception as e:
logger.warn("Failed to load as safetensors file, falling back to torch...", e) logger.warn("Failed to load as safetensors file, falling back to torch...", e)
checkpoint = torch.jit.load(checkpoint_file) checkpoint = torch.jit.load(checkpoint_file)
else: else:
logger.debug("Loading ckpt...") logger.debug("loading ckpt")
checkpoint = torch.load(checkpoint_file, map_location=map_location) checkpoint = torch.load(checkpoint_file, map_location=map_location)
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
@ -1221,7 +1221,7 @@ def extract_checkpoint(
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
if not is_512: if not is_512:
# v2.1 needs to upcast attention # v2.1 needs to upcast attention
logger.debug("Setting upcast_attention") logger.debug("setting upcast_attention")
upcast_attention = True upcast_attention = True
v2 = True v2 = True
else: else:
@ -1249,10 +1249,10 @@ def extract_checkpoint(
original_config_file = config_check original_config_file = config_check
if original_config_file is None or not os.path.exists(original_config_file): if original_config_file is None or not os.path.exists(original_config_file):
logger.warning("Unable to select a config file: %s" % (original_config_file)) logger.warning("unable to select a config file: %s" % (original_config_file))
return return
logger.debug(f"Trying to load: {original_config_file}") logger.debug("trying to load: %s", original_config_file)
original_config = load_yaml(original_config_file) original_config = load_yaml(original_config_file)
num_train_timesteps = original_config.model.params.timesteps num_train_timesteps = original_config.model.params.timesteps
@ -1291,7 +1291,7 @@ def extract_checkpoint(
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
# Convert the UNet2DConditionModel model. # Convert the UNet2DConditionModel model.
logger.info("Converting UNet...") logger.info("converting UNet")
unet_config = create_unet_diffusers_config(original_config, image_size=image_size) unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention unet_config["upcast_attention"] = upcast_attention
unet = UNet2DConditionModel(**unet_config) unet = UNet2DConditionModel(**unet_config)
@ -1304,7 +1304,7 @@ def extract_checkpoint(
unet.load_state_dict(converted_unet_checkpoint) unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model. # Convert the VAE model.
logger.info("Converting VAE...") logger.info("converting VAE")
vae_config = create_vae_diffusers_config(original_config, image_size=image_size) vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
@ -1312,7 +1312,7 @@ def extract_checkpoint(
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint)
# Convert the text model. # Convert the text model.
logger.info("Converting text encoder...") logger.info("converting text encoder")
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenOpenCLIPEmbedder": if text_model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint) text_model = convert_open_clip_checkpoint(checkpoint)
@ -1360,15 +1360,15 @@ def extract_checkpoint(
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
scheduler=scheduler) scheduler=scheduler)
except Exception: except Exception:
logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info())) logger.error("exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
pipe = None pipe = None
if pipe is None or db_config is None: if pipe is None or db_config is None:
msg = "Pipeline or config is not set, unable to continue." msg = "pipeline or config is not set, unable to continue."
logger.error(msg) logger.error(msg)
return return
else: else:
logger.info("Saving diffusion model...") logger.info("saving diffusion model")
pipe.save_pretrained(db_config.pretrained_model_name_or_path) pipe.save_pretrained(db_config.pretrained_model_name_or_path)
result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}" result_status = f"Checkpoint successfully extracted to {db_config.pretrained_model_name_or_path}"
revision = db_config.revision revision = db_config.revision
@ -1413,10 +1413,10 @@ def convert_diffusion_original(
source = source or model["source"] source = source or model["source"]
dest = os.path.join(ctx.model_path, name) dest = os.path.join(ctx.model_path, name)
logger.info("Converting original Diffusers checkpoint %s: %s -> %s", name, source, dest) logger.info("converting original Diffusers checkpoint %s: %s -> %s", name, source, dest)
if os.path.exists(dest): if os.path.exists(dest):
logger.info("ONNX pipeline already exists, skipping.") logger.info("ONNX pipeline already exists, skipping")
return return
torch_name = name + "-torch" torch_name = name + "-torch"
@ -1424,11 +1424,11 @@ def convert_diffusion_original(
working_name = os.path.join(ctx.cache_path, torch_name, "working") working_name = os.path.join(ctx.cache_path, torch_name, "working")
if os.path.exists(torch_path): if os.path.exists(torch_path):
logger.info("Torch pipeline already exists, reusing: %s", torch_path) logger.info("torch pipeline already exists, reusing: %s", torch_path)
else: else:
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path) logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
extract_checkpoint(ctx, torch_name, source, config_file=model.get("config")) extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"))
logger.info("Converted original Diffusers checkpoint to Torch model.") logger.info("converted original Diffusers checkpoint to Torch model")
convert_diffusion_stable(ctx, model, working_name) convert_diffusion_stable(ctx, model, working_name)
logger.info("ONNX pipeline saved to %s", name) logger.info("ONNX pipeline saved to %s", name)

View File

@ -84,7 +84,7 @@ def convert_diffusion_stable(
logger.info("converting model with single VAE") logger.info("converting model with single VAE")
if path.exists(dest_path): if path.exists(dest_path):
logger.info("ONNX model already exists, skipping.") logger.info("ONNX model already exists, skipping")
return return
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = StableDiffusionPipeline.from_pretrained(

View File

@ -24,7 +24,7 @@ def convert_upscale_resrgan(
logger.info("converting Real ESRGAN model: %s -> %s", name, dest) logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
if path.isfile(dest): if path.isfile(dest):
logger.info("ONNX model already exists, skipping.") logger.info("ONNX model already exists, skipping")
return return
logger.info("loading and training model") logger.info("loading and training model")
@ -65,4 +65,4 @@ def convert_upscale_resrgan(
opset_version=ctx.opset, opset_version=ctx.opset,
export_params=True, export_params=True,
) )
logger.info("Real ESRGAN exported to ONNX successfully.") logger.info("real ESRGAN exported to ONNX successfully")

View File

@ -49,7 +49,7 @@ def download_progress(urls: List[Tuple[str, str]]):
dest_path.parent.mkdir(parents=True, exist_ok=True) dest_path.parent.mkdir(parents=True, exist_ok=True)
if dest_path.exists(): if dest_path.exists():
logger.debug("Destination already exists: %s", dest_path) logger.debug("destination already exists: %s", dest_path)
return str(dest_path.absolute()) return str(dest_path.absolute())
req = requests.get( req = requests.get(

View File

@ -41,7 +41,7 @@ def unload(exclude):
to_unload.append(mod) to_unload.append(mod)
break break
logger.debug("Unloading modules for patching: %s", to_unload) logger.debug("unloading modules for patching: %s", to_unload)
for mod in to_unload: for mod in to_unload:
del sys.modules[mod] del sys.modules[mod]
@ -126,28 +126,28 @@ def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str:
cache_path = path.basename(parsed.path) cache_path = path.basename(parsed.path)
cache_path = path.join(ctx.cache_path, cache_path) cache_path = path.join(ctx.cache_path, cache_path)
logger.debug("Patching download path: %s -> %s", url, cache_path) logger.debug("patching download path: %s -> %s", url, cache_path)
if path.exists(cache_path): if path.exists(cache_path):
return cache_path return cache_path
else: else:
raise FileNotFoundError("Missing cache file: %s" % (cache_path)) raise FileNotFoundError("missing cache file: %s" % (cache_path))
def apply_patch_basicsr(ctx: ServerContext): def apply_patch_basicsr(ctx: ServerContext):
logger.debug("Patching BasicSR module...") logger.debug("patching BasicSR module")
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx) basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, ctx)
def apply_patch_codeformer(ctx: ServerContext): def apply_patch_codeformer(ctx: ServerContext):
logger.debug("Patching CodeFormer module...") logger.debug("patching CodeFormer module")
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx) codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, ctx)
def apply_patch_facexlib(ctx: ServerContext): def apply_patch_facexlib(ctx: ServerContext):
logger.debug("Patching Facexlib module...") logger.debug("patching Facexlib module")
facexlib.utils.load_file_from_url = partial(patch_cache_path, ctx) facexlib.utils.load_file_from_url = partial(patch_cache_path, ctx)

View File

@ -26,18 +26,18 @@ class ModelCache:
def set(self, tag: str, key: Any, value: Any) -> None: def set(self, tag: str, key: Any, value: Any) -> None:
if self.limit == 0: if self.limit == 0:
logger.debug("Cache limit set to 0, not caching model: %s", tag) logger.debug("cache limit set to 0, not caching model: %s", tag)
return return
for i in range(len(self.cache)): for i in range(len(self.cache)):
t, k, v = self.cache[i] t, k, v = self.cache[i]
if tag == t: if tag == t:
if key != k: if key != k:
logger.debug("Updating model cache: %s", tag) logger.debug("updating model cache: %s", tag)
self.cache[i] = (tag, key, value) self.cache[i] = (tag, key, value)
return return
logger.debug("Adding new model to cache: %s", tag) logger.debug("adding new model to cache: %s", tag)
self.cache.append((tag, key, value)) self.cache.append((tag, key, value))
self.prune() self.prune()
@ -49,4 +49,4 @@ class ModelCache:
) )
self.cache[:] = self.cache[-self.limit :] self.cache[:] = self.cache[-self.limit :]
else: else:
logger.debug("Model cache below limit, %s of %s", total, self.limit) logger.debug("model cache below limit, %s of %s", total, self.limit)