fix(api): add extra models to convert script
This commit is contained in:
parent
b1e7ab0a3e
commit
e0834110fc
10
README.md
10
README.md
|
@ -252,6 +252,8 @@ already listed in the `convert.py` script, including:
|
|||
- https://huggingface.co/runwayml/stable-diffusion-inpainting
|
||||
- https://huggingface.co/stabilityai/stable-diffusion-2-1
|
||||
- https://huggingface.co/stabilityai/stable-diffusion-2-inpainting
|
||||
- https://huggingface.co/Aybeeceedee/knollingcase
|
||||
- https://huggingface.co/prompthero/openjourney
|
||||
|
||||
You will need at least one of the base models for txt2img and img2img mode. If you want to use inpainting, you will
|
||||
also need one of the inpainting models. The upscaling and face correction models are downloaded from Github by the
|
||||
|
@ -273,15 +275,15 @@ paste it into the prompt.
|
|||
Run the provided conversion script from the `api/` directory:
|
||||
|
||||
```shell
|
||||
> python -m onnx_web.convert --diffusers --gfpgan --resrgan
|
||||
> python -m onnx_web.convert --diffusion --correction --upscaling
|
||||
```
|
||||
|
||||
Models that have already been downloaded and converted will be skipped, so it should be safe to run this script after
|
||||
every update. Some additional, more specialized models are available using the `--extras` flag.
|
||||
|
||||
The conversion script has a few other options, which can be printed using `python -m onnx_web.convert --help`. If you
|
||||
are using CUDA on Nvidia hardware, using the `--half` option may make things faster.
|
||||
|
||||
Models that have already been downloaded and converted will be skipped, so it should be safe to run this script after
|
||||
every update.
|
||||
|
||||
This will take a little while to convert each model. Stable diffusion v1.4 is about 6GB, v1.5 is at least 10GB or so.
|
||||
You can skip certain models by including a `--skip name` argument if you want to save time or disk space. For example,
|
||||
using `--skip stable-diffusion-onnx-v2-inpainting --skip stable-diffusion-onnx-v2-1` will not download the Stable
|
||||
|
|
|
@ -8,12 +8,15 @@ from pathlib import Path
|
|||
from shutil import copyfile, rmtree
|
||||
from sys import exit
|
||||
from torch.onnx import export
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
sources: Dict[str, List[Tuple[str, str]]] = {
|
||||
'diffusers': [
|
||||
Models = Dict[str, List[Tuple[str, str, Union[int, None]]]]
|
||||
|
||||
# recommended models
|
||||
base_models: Models = {
|
||||
'diffusion': [
|
||||
# v1.x
|
||||
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
|
||||
('stable-diffusion-onnx-v1-inpainting',
|
||||
|
@ -23,11 +26,11 @@ sources: Dict[str, List[Tuple[str, str]]] = {
|
|||
('stable-diffusion-onnx-v2-inpainting',
|
||||
'stabilityai/stable-diffusion-2-inpainting'),
|
||||
],
|
||||
'gfpgan': [
|
||||
'correction': [
|
||||
('correction-gfpgan-v1-3',
|
||||
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4),
|
||||
],
|
||||
'real_esrgan': [
|
||||
'upscaling': [
|
||||
('upscaling-real-esrgan-x2-plus',
|
||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 2),
|
||||
('upscaling-real-esrgan-x4-plus',
|
||||
|
@ -37,6 +40,16 @@ sources: Dict[str, List[Tuple[str, str]]] = {
|
|||
],
|
||||
}
|
||||
|
||||
# other neat models
|
||||
extra_models: Models = {
|
||||
'diffusion': [
|
||||
('diffusion-knollingcase', 'Aybeeceedee/knollingcase'),
|
||||
('diffusion-openjourney', 'prompthero/openjourney'),
|
||||
],
|
||||
'correction': [],
|
||||
'upscaling': [],
|
||||
}
|
||||
|
||||
model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
||||
path.join('..', 'models'))
|
||||
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
@ -175,7 +188,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool):
|
|||
'''
|
||||
dtype = torch.float16 if half else torch.float32
|
||||
dest_path = path.join(model_path, name)
|
||||
print('converting Diffusers model: %s -> %s' % (name, dest_path))
|
||||
|
||||
# diffusers go into a directory rather than .onnx file
|
||||
print('converting Diffusers model: %s -> %s/' % (name, dest_path))
|
||||
|
||||
if path.isdir(dest_path):
|
||||
print('ONNX model already exists, skipping.')
|
||||
|
@ -354,6 +369,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool):
|
|||
requires_safety_checker=safety_checker is not None,
|
||||
)
|
||||
|
||||
print('exporting ONNX model')
|
||||
|
||||
onnx_pipeline.save_pretrained(output_path)
|
||||
print("ONNX pipeline saved to", output_path)
|
||||
|
||||
|
@ -364,15 +381,39 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool):
|
|||
print("ONNX pipeline is loadable")
|
||||
|
||||
|
||||
def load_models(args, models: Models):
|
||||
if args.diffusion:
|
||||
for source in models.get('diffusion'):
|
||||
if source[0] in args.skip:
|
||||
print('Skipping model: %s' % source[0])
|
||||
else:
|
||||
convert_diffuser(*source, args.opset, args.half)
|
||||
|
||||
if args.upscaling:
|
||||
for source in models.get('upscaling'):
|
||||
if source[0] in args.skip:
|
||||
print('Skipping model: %s' % source[0])
|
||||
else:
|
||||
convert_real_esrgan(*source, args.opset)
|
||||
|
||||
if args.correction:
|
||||
for source in models.get('correction'):
|
||||
if source[0] in args.skip:
|
||||
print('Skipping model: %s' % source[0])
|
||||
else:
|
||||
convert_gfpgan(*source, args.opset)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = ArgumentParser(
|
||||
prog='onnx-web model converter',
|
||||
description='convert checkpoint models to ONNX')
|
||||
|
||||
# model groups
|
||||
parser.add_argument('--diffusers', action='store_true', default=False)
|
||||
parser.add_argument('--gfpgan', action='store_true', default=False)
|
||||
parser.add_argument('--resrgan', action='store_true', default=False)
|
||||
parser.add_argument('--correction', action='store_true', default=False)
|
||||
parser.add_argument('--diffusion', action='store_true', default=False)
|
||||
parser.add_argument('--extras', action='store_true', default=False)
|
||||
parser.add_argument('--upscaling', action='store_true', default=False)
|
||||
parser.add_argument('--skip', nargs='*', type=str, default=[])
|
||||
|
||||
# export options
|
||||
|
@ -392,26 +433,12 @@ def main() -> int:
|
|||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.diffusers:
|
||||
for source in sources.get('diffusers'):
|
||||
if source[0] in args.skip:
|
||||
print('Skipping model: %s' % source[0])
|
||||
else:
|
||||
convert_diffuser(*source, args.opset, args.half)
|
||||
print('Converting base models.')
|
||||
load_models(args, base_models)
|
||||
|
||||
if args.resrgan:
|
||||
for source in sources.get('real_esrgan'):
|
||||
if source[0] in args.skip:
|
||||
print('Skipping model: %s' % source[0])
|
||||
else:
|
||||
convert_real_esrgan(*source, args.opset)
|
||||
|
||||
if args.gfpgan:
|
||||
for source in sources.get('gfpgan'):
|
||||
if source[0] in args.skip:
|
||||
print('Skipping model: %s' % source[0])
|
||||
else:
|
||||
convert_gfpgan(*source, args.opset)
|
||||
if args.extras:
|
||||
print('Converting extra models.')
|
||||
load_models(args, extra_models)
|
||||
|
||||
return 0
|
||||
|
||||
|
|
Loading…
Reference in New Issue