diff --git a/api/extras.json b/api/extras.json new file mode 100644 index 00000000..fb772bac --- /dev/null +++ b/api/extras.json @@ -0,0 +1,8 @@ +{ + "diffusion": [ + ["diffusion-knollingcase", "Aybeeceedee/knollingcase"], + ["diffusion-openjourney", "prompthero/openjourney"] + ], + "correction": [], + "upscaling": [] +} \ No newline at end of file diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py index 4834b5b7..96001e68 100644 --- a/api/onnx_web/convert.py +++ b/api/onnx_web/convert.py @@ -8,6 +8,7 @@ from diffusers import ( StableDiffusionPipeline, StableDiffusionUpscalePipeline, ) +from json import loads from logging import getLogger from onnx import load, save_model from os import environ, makedirs, mkdir, path @@ -60,16 +61,6 @@ base_models: Models = { ], } -# other neat models -extra_models: Models = { - 'diffusion': [ - ('diffusion-knollingcase', 'Aybeeceedee/knollingcase'), - ('diffusion-openjourney', 'prompthero/openjourney'), - ], - 'correction': [], - 'upscaling': [], -} - model_path = environ.get('ONNX_WEB_MODEL_PATH', path.join('..', 'models')) training_device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -491,8 +482,10 @@ def main() -> int: # model groups parser.add_argument('--correction', action='store_true', default=False) parser.add_argument('--diffusion', action='store_true', default=False) - parser.add_argument('--extras', action='store_true', default=False) parser.add_argument('--upscaling', action='store_true', default=False) + + # extra models + parser.add_argument('--extras', nargs='*', type=str, default=[]) parser.add_argument('--skip', nargs='*', type=str, default=[]) # export options @@ -500,7 +493,7 @@ def main() -> int: '--half', action='store_true', default=False, - help='Export models for half precision, faster on some Nvidia cards' + help='Export models for half precision, faster on some Nvidia cards.' ) parser.add_argument( '--opset', @@ -524,9 +517,15 @@ def main() -> int: logger.info('Converting base models.') load_models(args, base_models) - if args.extras: - logger.info('Converting extra models.') - load_models(args, extra_models) + for file in args.extras: + logger.info('Loading extra models from %s', file) + try: + with open(file, 'r') as f: + data = loads(f.read()) + logger.info('Converting extra models.') + load_models(args, data) + except Exception as err: + logger.error('Error converting extra models: %s', err) return 0