From 2a7621c195c01000badf74073fd95803fffbb99a Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 10 Apr 2023 19:02:12 -0500 Subject: [PATCH] feat(api): add params for more SwinIR models --- api/onnx_web/__init__.py | 2 +- api/onnx_web/chain/upscale_bsrgan.py | 14 +++---- api/onnx_web/chain/upscale_resrgan.py | 6 +-- api/onnx_web/chain/upscale_swinir.py | 2 + api/onnx_web/convert/__main__.py | 6 +++ api/onnx_web/convert/correction/gfpgan.py | 2 +- api/onnx_web/convert/upscaling/bsrgan.py | 1 + api/onnx_web/convert/upscaling/resrgan.py | 2 +- api/onnx_web/convert/upscaling/swinir.py | 51 ++++++++++++++++++----- api/onnx_web/onnx/__init__.py | 2 +- api/onnx_web/onnx/onnx_net.py | 4 +- 11 files changed, 66 insertions(+), 26 deletions(-) diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index 9cdd7cd4..9697d7b3 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -30,7 +30,7 @@ from .image import ( noise_source_uniform, valid_image, ) -from .onnx import OnnxNet, OnnxTensor +from .onnx import OnnxRRDBNet, OnnxTensor from .params import ( Border, DeviceParams, diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index cdcfe1ba..37350d09 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -26,10 +26,10 @@ def load_bsrgan( cache_pipe = server.cache.get("bsrgan", cache_key) if cache_pipe is not None: - logger.info("reusing existing BSRGAN pipeline") + logger.debug("reusing existing BSRGAN pipeline") return cache_pipe - logger.debug("loading BSRGAN model from %s", model_path) + logger.info("loading BSRGAN model from %s", model_path) pipe = OnnxModel( server, @@ -62,7 +62,7 @@ def upscale_bsrgan( logger.warn("no upscaling model given, skipping") return source - logger.info("correcting faces with BSRGAN model: %s", upscale.upscale_model) + logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model) device = job.get_device() bsrgan = load_bsrgan(server, stage, upscale, device) @@ -73,13 +73,13 @@ def upscale_bsrgan( image = np.array(source) / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = np.expand_dims(image, axis=0) - logger.info("BSRGAN input shape: %s", image.shape) + logger.trace("BSRGAN input shape: %s", image.shape) scale = upscale.outscale dest = np.zeros( (image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale) ) - logger.info("BSRGAN output shape: %s", dest.shape) + logger.trace("BSRGAN output shape: %s", dest.shape) for x in range(tile_x): for y in range(tile_y): @@ -90,7 +90,7 @@ def upscale_bsrgan( ix2 = xt + tile_size[0] iy1 = yt iy2 = yt + tile_size[1] - logger.info( + logger.debug( "running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)", ix1, ix2, @@ -114,5 +114,5 @@ def upscale_bsrgan( dest = (dest * 255.0).round().astype(np.uint8) output = Image.fromarray(dest, "RGB") - logger.info("output image size: %s x %s", output.width, output.height) + logger.debug("output image size: %s x %s", output.width, output.height) return output diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 1f8c8f07..e84ba3de 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -5,7 +5,8 @@ from typing import Optional import numpy as np from PIL import Image -from ..onnx import OnnxNet +from ..models.rrdb import RRDBNet +from ..onnx import OnnxRRDBNet from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..utils import run_gc @@ -20,7 +21,6 @@ def load_resrgan( server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0 ): # must be within load function for patches to take effect - from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer from realesrgan.archs.srvgg_arch import SRVGGNetCompact @@ -38,7 +38,7 @@ def load_resrgan( if params.format == "onnx": # use ONNX acceleration, if available - model = OnnxNet( + model = OnnxRRDBNet( server, model_file, provider=device.ort_provider(), diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index 8e96d773..aaf466cc 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -66,10 +66,12 @@ def upscale_swinir( device = job.get_device() swinir = load_swinir(server, stage, upscale, device) + # TODO: add support for other sizes tile_size = (64, 64) tile_x = source.width // tile_size[0] tile_y = source.height // tile_size[1] + # TODO: add support for grayscale (1-channel) images image = np.array(source) / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = np.expand_dims(image, axis=0) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 2aeb3f58..d6ed1db0 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -122,6 +122,12 @@ base_models: Models = { "source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth", "scale": 4, }, + { + "model": "bsrgan", + "name": "upscaling-bsrgan-x2", + "source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGANx2.pth", + "scale": 2, + }, ], # download only "sources": [ diff --git a/api/onnx_web/convert/correction/gfpgan.py b/api/onnx_web/convert/correction/gfpgan.py index 5f1d892f..0662754b 100644 --- a/api/onnx_web/convert/correction/gfpgan.py +++ b/api/onnx_web/convert/correction/gfpgan.py @@ -2,9 +2,9 @@ from logging import getLogger from os import path import torch -from basicsr.archs.rrdbnet_arch import RRDBNet from torch.onnx import export +from ...models.rrdb import RRDBNet from ..utils import ConversionContext, ModelDict logger = getLogger(__name__) diff --git a/api/onnx_web/convert/upscaling/bsrgan.py b/api/onnx_web/convert/upscaling/bsrgan.py index a757eb78..c8f1132f 100644 --- a/api/onnx_web/convert/upscaling/bsrgan.py +++ b/api/onnx_web/convert/upscaling/bsrgan.py @@ -28,6 +28,7 @@ def convert_upscaling_bsrgan( return logger.info("loading and training model") + # values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69 model = RRDBNet( in_nc=3, out_nc=3, diff --git a/api/onnx_web/convert/upscaling/resrgan.py b/api/onnx_web/convert/upscaling/resrgan.py index 96fc01b8..c06bae06 100644 --- a/api/onnx_web/convert/upscaling/resrgan.py +++ b/api/onnx_web/convert/upscaling/resrgan.py @@ -4,6 +4,7 @@ from os import path import torch from torch.onnx import export +from ...models.rrdb import RRDBNet from ..utils import ConversionContext, ModelDict logger = getLogger(__name__) @@ -17,7 +18,6 @@ def convert_upscale_resrgan( model: ModelDict, source: str, ): - from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan.archs.srvgg_arch import SRVGGNetCompact name = model.get("name") diff --git a/api/onnx_web/convert/upscaling/swinir.py b/api/onnx_web/convert/upscaling/swinir.py index 7124353a..b8b39f12 100644 --- a/api/onnx_web/convert/upscaling/swinir.py +++ b/api/onnx_web/convert/upscaling/swinir.py @@ -28,19 +28,50 @@ def convert_upscaling_swinir( return logger.info("loading and training model") - img_size = (64, 64) # TODO: does this need to be a fixed value? + # values based on https://github.com/JingyunLiang/SwinIR/blob/main/main_test_swinir.py#L128 + params = { + "depths": [6, 6, 6, 6, 6, 6], + "embed_dim": 180, + "img_range": 1.0, + "img_size": (64, 64), + "in_chans": 3, + "num_heads": [6, 6, 6, 6, 6, 6], + "resi_connection": "1conv", + "upsampler": "pixelshuffle", + "window_size": 8, + } + + if "lightweight" in name: + logger.debug("using SwinIR lightweight params") + params["depths"] = [6, 6, 6, 6] + params["embed_dim"] = 60 + params["num_heads"] = [6, 6, 6, 6] + params["upsampler"] = "pixelshuffledirect" + elif "real" in name: + # TODO: add params for large model + logger.debug("using SwinIR real params") + params["upsampler"] = "nearest+conv" + elif "gray_dn" in name: + params["img_size"] = (128, 128) + params["in_chans"] = 1 + params["upsampler"] = "" + elif "color_dn" in name: + params["img_size"] = (128, 128) + params["upsampler"] = "" + elif "gray_jpeg" in name: + params["img_size"] = (126, 126) + params["in_chans"] = 1 + params["upsampler"] = "" + params["window_size"] = 7 + elif "color_jpeg" in name: + params["img_size"] = (126, 126) + params["upsampler"] = "" + params["window_size"] = 7 + model = SwinIR( - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - img_range=1.0, - img_size=img_size, - in_chans=3, mlp_ratio=2, - num_heads=[6, 6, 6, 6, 6, 6], - resi_connection="1conv", upscale=scale, - upsampler="pixelshuffle", - window_size=8, + **params, ) torch_model = torch.load(source, map_location=conversion.map_location) diff --git a/api/onnx_web/onnx/__init__.py b/api/onnx_web/onnx/__init__.py index 8e58b60a..2445ea19 100644 --- a/api/onnx_web/onnx/__init__.py +++ b/api/onnx_web/onnx/__init__.py @@ -1 +1 @@ -from .onnx_net import OnnxTensor, OnnxNet +from .onnx_net import OnnxTensor, OnnxRRDBNet diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 036c5d0d..7e43893b 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -37,9 +37,9 @@ class OnnxTensor: return np.shape(self.source) -class OnnxNet: +class OnnxRRDBNet: """ - Provides the RRDBNet interface using an ONNX session for DirectML acceleration. + Provides the RRDBNet interface using an ONNX session. """ def __init__(