diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py index 6a9d167d..5889646d 100644 --- a/api/onnx_web/convert.py +++ b/api/onnx_web/convert.py @@ -16,28 +16,30 @@ sources: Dict[str, List[Tuple[str, str]]] = { 'diffusers': [ # v1.x ('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'), - ('stable-diffusion-onnx-v1-inpainting', 'runwayml/stable-diffusion-inpainting'), + ('stable-diffusion-onnx-v1-inpainting', + 'runwayml/stable-diffusion-inpainting'), # v2.x ('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'), - ('stable-diffusion-onnx-v2-inpainting', 'stabilityai/stable-diffusion-2-inpainting'), + ('stable-diffusion-onnx-v2-inpainting', + 'stabilityai/stable-diffusion-2-inpainting'), ], 'gfpgan': [ - ('GFPGANv1.3', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'), + ('correction-gfpgan-v1-3', + 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'), ], 'real_esrgan': [ - ('RealESRGAN_x4plus', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'), + ('upscaling-real-esrgan-x4-plus', + 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4), ], } model_path = environ.get('ONNX_WEB_MODEL_PATH', path.join('..', 'models')) - - training_device = 'cuda' if torch.cuda.is_available() else 'cpu' @torch.no_grad() -def convert_real_esrgan(name: str, url: str, opset: int): +def convert_real_esrgan(name: str, url: str, scale: int, opset: int): dest_path = path.join(model_path, name) dest_onnx = path.join(model_path, name + '.onnx') print('converting Real ESRGAN model: %s -> %s' % (name, dest_path)) @@ -53,7 +55,7 @@ def convert_real_esrgan(name: str, url: str, opset: int): print('loading and training model') model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, - num_block=23, num_grow_ch=32, scale=4) + num_block=23, num_grow_ch=32, scale=scale) model.load_state_dict(torch.load(dest_path)['params_ema']) model.to(training_device).train(False) model.eval()