1
0
Fork 0

collect and summarize errors while converting models

This commit is contained in:
Sean Sube 2023-04-29 22:58:58 -05:00
parent bc71583393
commit 1c308e0bab
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 12 additions and 0 deletions

View File

@ -262,6 +262,8 @@ def fetch_model(
def convert_models(conversion: ConversionContext, args, models: Models): def convert_models(conversion: ConversionContext, args, models: Models):
model_errors = []
if args.sources and "sources" in models: if args.sources and "sources" in models:
for model in models.get("sources"): for model in models.get("sources"):
model = tuple_to_source(model) model = tuple_to_source(model)
@ -284,6 +286,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
logger.info("finished downloading source: %s -> %s", source, dest) logger.info("finished downloading source: %s -> %s", source, dest)
except Exception: except Exception:
logger.exception("error fetching source %s", name) logger.exception("error fetching source %s", name)
model_errors.append(name)
if args.networks and "networks" in models: if args.networks and "networks" in models:
for network in models.get("networks"): for network in models.get("networks"):
@ -333,6 +336,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
logger.info("finished downloading network: %s -> %s", source, dest) logger.info("finished downloading network: %s -> %s", source, dest)
except Exception: except Exception:
logger.exception("error fetching network %s", name) logger.exception("error fetching network %s", name)
model_errors.append(name)
if args.diffusion and "diffusion" in models: if args.diffusion and "diffusion" in models:
for model in models.get("diffusion"): for model in models.get("diffusion"):
@ -471,6 +475,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
"error converting diffusion model %s", "error converting diffusion model %s",
name, name,
) )
model_errors.append(name)
if args.upscaling and "upscaling" in models: if args.upscaling and "upscaling" in models:
for model in models.get("upscaling"): for model in models.get("upscaling"):
@ -497,11 +502,13 @@ def convert_models(conversion: ConversionContext, args, models: Models):
logger.error( logger.error(
"unknown upscaling model type %s for %s", model_type, name "unknown upscaling model type %s for %s", model_type, name
) )
model_errors.append(name)
except Exception: except Exception:
logger.exception( logger.exception(
"error converting upscaling model %s", "error converting upscaling model %s",
name, name,
) )
model_errors.append(name)
if args.correction and "correction" in models: if args.correction and "correction" in models:
for model in models.get("correction"): for model in models.get("correction"):
@ -523,11 +530,16 @@ def convert_models(conversion: ConversionContext, args, models: Models):
logger.error( logger.error(
"unknown correction model type %s for %s", model_type, name "unknown correction model type %s for %s", model_type, name
) )
model_errors.append(name)
except Exception: except Exception:
logger.exception( logger.exception(
"error converting correction model %s", "error converting correction model %s",
name, name,
) )
model_errors.append(name)
if len(model_errors) > 0:
logger.error("error while converting models: %s", model_errors)
def main() -> int: def main() -> int: