1
0
Fork 0

feat(api): provide a way for users to add models to the convert list (#70)

This commit is contained in:
Sean Sube 2023-02-02 08:59:47 -06:00
parent 0050cea694
commit c837830043
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 22 additions and 15 deletions

8
api/extras.json Normal file
View File

@ -0,0 +1,8 @@
{
"diffusion": [
["diffusion-knollingcase", "Aybeeceedee/knollingcase"],
["diffusion-openjourney", "prompthero/openjourney"]
],
"correction": [],
"upscaling": []
}

View File

@ -8,6 +8,7 @@ from diffusers import (
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from json import loads
from logging import getLogger
from onnx import load, save_model
from os import environ, makedirs, mkdir, path
@ -60,16 +61,6 @@ base_models: Models = {
],
}
# other neat models
extra_models: Models = {
'diffusion': [
('diffusion-knollingcase', 'Aybeeceedee/knollingcase'),
('diffusion-openjourney', 'prompthero/openjourney'),
],
'correction': [],
'upscaling': [],
}
model_path = environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models'))
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -491,8 +482,10 @@ def main() -> int:
# model groups
parser.add_argument('--correction', action='store_true', default=False)
parser.add_argument('--diffusion', action='store_true', default=False)
parser.add_argument('--extras', action='store_true', default=False)
parser.add_argument('--upscaling', action='store_true', default=False)
# extra models
parser.add_argument('--extras', nargs='*', type=str, default=[])
parser.add_argument('--skip', nargs='*', type=str, default=[])
# export options
@ -500,7 +493,7 @@ def main() -> int:
'--half',
action='store_true',
default=False,
help='Export models for half precision, faster on some Nvidia cards'
help='Export models for half precision, faster on some Nvidia cards.'
)
parser.add_argument(
'--opset',
@ -524,9 +517,15 @@ def main() -> int:
logger.info('Converting base models.')
load_models(args, base_models)
if args.extras:
logger.info('Converting extra models.')
load_models(args, extra_models)
for file in args.extras:
logger.info('Loading extra models from %s', file)
try:
with open(file, 'r') as f:
data = loads(f.read())
logger.info('Converting extra models.')
load_models(args, data)
except Exception as err:
logger.error('Error converting extra models: %s', err)
return 0