
feat(api): add params for more SwinIR models

Sean Sube 2023-04-10 19:02:12 -05:00
parent 23fb752bb6
commit 2a7621c195
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
11 changed files with 66 additions and 26 deletions

View File

@@ -30,7 +30,7 @@ from .image import (
noise_source_uniform,
valid_image,
)
-from .onnx import OnnxNet, OnnxTensor
+from .onnx import OnnxRRDBNet, OnnxTensor
from .params import (
Border,
DeviceParams,

View File

@@ -26,10 +26,10 @@ def load_bsrgan(
cache_pipe = server.cache.get("bsrgan", cache_key)
if cache_pipe is not None:
logger.info("reusing existing BSRGAN pipeline")
logger.debug("reusing existing BSRGAN pipeline")
return cache_pipe
logger.debug("loading BSRGAN model from %s", model_path)
logger.info("loading BSRGAN model from %s", model_path)
pipe = OnnxModel(
server,
@@ -62,7 +62,7 @@ def upscale_bsrgan(
logger.warn("no upscaling model given, skipping")
return source
logger.info("correcting faces with BSRGAN model: %s", upscale.upscale_model)
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
device = job.get_device()
bsrgan = load_bsrgan(server, stage, upscale, device)
@@ -73,13 +73,13 @@ def upscale_bsrgan(
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.info("BSRGAN input shape: %s", image.shape)
logger.trace("BSRGAN input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
)
logger.info("BSRGAN output shape: %s", dest.shape)
logger.trace("BSRGAN output shape: %s", dest.shape)
for x in range(tile_x):
for y in range(tile_y):
@@ -90,7 +90,7 @@ def upscale_bsrgan(
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
-logger.info(
+logger.debug(
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
@@ -114,5 +114,5 @@ def upscale_bsrgan(
dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height)
logger.debug("output image size: %s x %s", output.width, output.height)
return output
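For context, the reshaping in this hunk converts a PIL image into the NCHW float tensor that the BSRGAN ONNX graph consumes, and the output tensor goes back to a PIL image at the end. A minimal standalone sketch of that round trip, using plain numpy and PIL rather than onnx-web's helpers (the function names and the inverse channel reorder are illustrative assumptions, not the project's code):

import numpy as np
from PIL import Image

def to_nchw(source: Image.Image) -> np.ndarray:
    # scale to [0, 1], reorder channels as in the [2, 1, 0] indexing above, and move to CHW
    image = np.array(source) / 255.0
    image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
    # add the batch dimension: (1, 3, height, width)
    return np.expand_dims(image, axis=0)

def from_nchw(dest: np.ndarray) -> Image.Image:
    # drop the batch dimension, undo the channel reorder, and restore HWC layout
    image = np.squeeze(dest, axis=0)[[2, 1, 0], :, :].transpose((1, 2, 0))
    image = np.clip(image, 0.0, 1.0)
    return Image.fromarray((image * 255.0).round().astype(np.uint8), "RGB")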

View File

@@ -5,7 +5,8 @@ from typing import Optional
import numpy as np
from PIL import Image
-from ..onnx import OnnxNet
+from ..models.rrdb import RRDBNet
+from ..onnx import OnnxRRDBNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc
@@ -20,7 +21,6 @@ def load_resrgan(
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
# must be within load function for patches to take effect
-from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
@@ -38,7 +38,7 @@ def load_resrgan(
if params.format == "onnx":
# use ONNX acceleration, if available
-model = OnnxNet(
+model = OnnxRRDBNet(
server,
model_file,
provider=device.ort_provider(),

View File

@@ -66,10 +66,12 @@ def upscale_swinir(
device = job.get_device()
swinir = load_swinir(server, stage, upscale, device)
+# TODO: add support for other sizes
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]
+# TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
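As a worked illustration of the fixed tiling (the source size here is an assumption, not from the diff): a 512 x 384 input with 64 x 64 tiles gives tile_x = 512 // 64 = 8 and tile_y = 384 // 64 = 6, so 48 tiles are processed; dimensions that are not an exact multiple of 64 leave a remainder outside this grid.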

View File

@@ -122,6 +122,12 @@ base_models: Models = {
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
"scale": 4,
},
+{
+"model": "bsrgan",
+"name": "upscaling-bsrgan-x2",
+"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGANx2.pth",
+"scale": 2,
+},
],
# download only
"sources": [

View File

@@ -2,9 +2,9 @@ from logging import getLogger
from os import path
import torch
-from basicsr.archs.rrdbnet_arch import RRDBNet
from torch.onnx import export
+from ...models.rrdb import RRDBNet
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)

View File

@@ -28,6 +28,7 @@ def convert_upscaling_bsrgan(
return
logger.info("loading and training model")
+# values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69
model = RRDBNet(
in_nc=3,
out_nc=3,
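The hunk is cut off after the first two arguments. For reference, the BSRGAN test script cited in the comment constructs the network roughly as below; the values come from that script, and the keyword names are assumed to carry over to the local RRDBNet rather than verified against it:

# sketch based on the cited main_test_bsrgan.py; names and values are assumptions about the truncated call
model = RRDBNet(
    in_nc=3,   # RGB input channels
    out_nc=3,  # RGB output channels
    nf=64,     # base feature channels
    nb=23,     # number of RRDB blocks
    gc=32,     # growth channels in each dense block
    sf=scale,  # upscaling factor: 4 for BSRGAN.pth, 2 for BSRGANx2.pth
)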

View File

@@ -4,6 +4,7 @@ from os import path
import torch
from torch.onnx import export
+from ...models.rrdb import RRDBNet
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@@ -17,7 +18,6 @@ def convert_upscale_resrgan(
model: ModelDict,
source: str,
):
-from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
name = model.get("name")

View File

@@ -28,19 +28,50 @@ def convert_upscaling_swinir(
return
logger.info("loading and training model")
-img_size = (64, 64) # TODO: does this need to be a fixed value?
+# values based on https://github.com/JingyunLiang/SwinIR/blob/main/main_test_swinir.py#L128
+params = {
+"depths": [6, 6, 6, 6, 6, 6],
+"embed_dim": 180,
+"img_range": 1.0,
+"img_size": (64, 64),
+"in_chans": 3,
+"num_heads": [6, 6, 6, 6, 6, 6],
+"resi_connection": "1conv",
+"upsampler": "pixelshuffle",
+"window_size": 8,
+}
+if "lightweight" in name:
+logger.debug("using SwinIR lightweight params")
+params["depths"] = [6, 6, 6, 6]
+params["embed_dim"] = 60
+params["num_heads"] = [6, 6, 6, 6]
+params["upsampler"] = "pixelshuffledirect"
+elif "real" in name:
+# TODO: add params for large model
+logger.debug("using SwinIR real params")
+params["upsampler"] = "nearest+conv"
+elif "gray_dn" in name:
+params["img_size"] = (128, 128)
+params["in_chans"] = 1
+params["upsampler"] = ""
+elif "color_dn" in name:
+params["img_size"] = (128, 128)
+params["upsampler"] = ""
+elif "gray_jpeg" in name:
+params["img_size"] = (126, 126)
+params["in_chans"] = 1
+params["upsampler"] = ""
+params["window_size"] = 7
+elif "color_jpeg" in name:
+params["img_size"] = (126, 126)
+params["upsampler"] = ""
+params["window_size"] = 7
model = SwinIR(
-depths=[6, 6, 6, 6, 6, 6],
-embed_dim=180,
-img_range=1.0,
-img_size=img_size,
-in_chans=3,
mlp_ratio=2,
-num_heads=[6, 6, 6, 6, 6, 6],
-resi_connection="1conv",
upscale=scale,
upsampler="pixelshuffle",
window_size=8,
**params,
)
torch_model = torch.load(source, map_location=conversion.map_location)
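To make the name-based dispatch above concrete: for a hypothetical model whose name contains "lightweight" (an illustrative name, not one of the project's model keys), the overrides collapse to the following before the SwinIR call, assuming a scale of 2:

# resolved params for a "lightweight" name; illustrative only, mirroring the branch above
params = {
    "depths": [6, 6, 6, 6],
    "embed_dim": 60,
    "img_range": 1.0,
    "img_size": (64, 64),
    "in_chans": 3,
    "num_heads": [6, 6, 6, 6],
    "resi_connection": "1conv",
    "upsampler": "pixelshuffledirect",
    "window_size": 8,
}
model = SwinIR(mlp_ratio=2, upscale=2, **params)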

View File

@@ -1 +1 @@
-from .onnx_net import OnnxTensor, OnnxNet
+from .onnx_net import OnnxTensor, OnnxRRDBNet

View File

@@ -37,9 +37,9 @@ class OnnxTensor:
return np.shape(self.source)
-class OnnxNet:
+class OnnxRRDBNet:
"""
-Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
+Provides the RRDBNet interface using an ONNX session.
"""
def __init__(
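The rest of the class is cut off here. As a generic illustration of the pattern this class follows, wrapping an onnxruntime session so it can stand in for a torch network, and not onnx-web's actual implementation:

# generic ONNX-session wrapper; the class and argument names are illustrative, not the project's API
import numpy as np
from onnxruntime import InferenceSession

class OnnxWrapper:
    def __init__(self, model_path: str, provider: str = "CPUExecutionProvider"):
        self.session = InferenceSession(model_path, providers=[provider])

    def __call__(self, image: np.ndarray) -> np.ndarray:
        # feed the first graph input and return the first output
        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        return self.session.run([output_name], {input_name: image})[0]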