feat(api): provide a way for users to add models to the convert list (#70)
This commit is contained in:
parent
0050cea694
commit
c837830043
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"diffusion": [
|
||||
["diffusion-knollingcase", "Aybeeceedee/knollingcase"],
|
||||
["diffusion-openjourney", "prompthero/openjourney"]
|
||||
],
|
||||
"correction": [],
|
||||
"upscaling": []
|
||||
}
|
|
@ -8,6 +8,7 @@ from diffusers import (
|
|||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from json import loads
|
||||
from logging import getLogger
|
||||
from onnx import load, save_model
|
||||
from os import environ, makedirs, mkdir, path
|
||||
|
@ -60,16 +61,6 @@ base_models: Models = {
|
|||
],
|
||||
}
|
||||
|
||||
# other neat models
|
||||
extra_models: Models = {
|
||||
'diffusion': [
|
||||
('diffusion-knollingcase', 'Aybeeceedee/knollingcase'),
|
||||
('diffusion-openjourney', 'prompthero/openjourney'),
|
||||
],
|
||||
'correction': [],
|
||||
'upscaling': [],
|
||||
}
|
||||
|
||||
model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
||||
path.join('..', 'models'))
|
||||
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
@ -491,8 +482,10 @@ def main() -> int:
|
|||
# model groups
|
||||
parser.add_argument('--correction', action='store_true', default=False)
|
||||
parser.add_argument('--diffusion', action='store_true', default=False)
|
||||
parser.add_argument('--extras', action='store_true', default=False)
|
||||
parser.add_argument('--upscaling', action='store_true', default=False)
|
||||
|
||||
# extra models
|
||||
parser.add_argument('--extras', nargs='*', type=str, default=[])
|
||||
parser.add_argument('--skip', nargs='*', type=str, default=[])
|
||||
|
||||
# export options
|
||||
|
@ -500,7 +493,7 @@ def main() -> int:
|
|||
'--half',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Export models for half precision, faster on some Nvidia cards'
|
||||
help='Export models for half precision, faster on some Nvidia cards.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--opset',
|
||||
|
@ -524,9 +517,15 @@ def main() -> int:
|
|||
logger.info('Converting base models.')
|
||||
load_models(args, base_models)
|
||||
|
||||
if args.extras:
|
||||
logger.info('Converting extra models.')
|
||||
load_models(args, extra_models)
|
||||
for file in args.extras:
|
||||
logger.info('Loading extra models from %s', file)
|
||||
try:
|
||||
with open(file, 'r') as f:
|
||||
data = loads(f.read())
|
||||
logger.info('Converting extra models.')
|
||||
load_models(args, data)
|
||||
except Exception as err:
|
||||
logger.error('Error converting extra models: %s', err)
|
||||
|
||||
return 0
|
||||
|
||||
|
|
Loading…
Reference in New Issue