feat(api): add params for more SwinIR models
This commit is contained in:
parent
23fb752bb6
commit
2a7621c195
|
@ -30,7 +30,7 @@ from .image import (
|
|||
noise_source_uniform,
|
||||
valid_image,
|
||||
)
|
||||
from .onnx import OnnxNet, OnnxTensor
|
||||
from .onnx import OnnxRRDBNet, OnnxTensor
|
||||
from .params import (
|
||||
Border,
|
||||
DeviceParams,
|
||||
|
|
|
@ -26,10 +26,10 @@ def load_bsrgan(
|
|||
cache_pipe = server.cache.get("bsrgan", cache_key)
|
||||
|
||||
if cache_pipe is not None:
|
||||
logger.info("reusing existing BSRGAN pipeline")
|
||||
logger.debug("reusing existing BSRGAN pipeline")
|
||||
return cache_pipe
|
||||
|
||||
logger.debug("loading BSRGAN model from %s", model_path)
|
||||
logger.info("loading BSRGAN model from %s", model_path)
|
||||
|
||||
pipe = OnnxModel(
|
||||
server,
|
||||
|
@ -62,7 +62,7 @@ def upscale_bsrgan(
|
|||
logger.warn("no upscaling model given, skipping")
|
||||
return source
|
||||
|
||||
logger.info("correcting faces with BSRGAN model: %s", upscale.upscale_model)
|
||||
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
|
||||
device = job.get_device()
|
||||
bsrgan = load_bsrgan(server, stage, upscale, device)
|
||||
|
||||
|
@ -73,13 +73,13 @@ def upscale_bsrgan(
|
|||
image = np.array(source) / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0)
|
||||
logger.info("BSRGAN input shape: %s", image.shape)
|
||||
logger.trace("BSRGAN input shape: %s", image.shape)
|
||||
|
||||
scale = upscale.outscale
|
||||
dest = np.zeros(
|
||||
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
|
||||
)
|
||||
logger.info("BSRGAN output shape: %s", dest.shape)
|
||||
logger.trace("BSRGAN output shape: %s", dest.shape)
|
||||
|
||||
for x in range(tile_x):
|
||||
for y in range(tile_y):
|
||||
|
@ -90,7 +90,7 @@ def upscale_bsrgan(
|
|||
ix2 = xt + tile_size[0]
|
||||
iy1 = yt
|
||||
iy2 = yt + tile_size[1]
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
|
||||
ix1,
|
||||
ix2,
|
||||
|
@ -114,5 +114,5 @@ def upscale_bsrgan(
|
|||
dest = (dest * 255.0).round().astype(np.uint8)
|
||||
|
||||
output = Image.fromarray(dest, "RGB")
|
||||
logger.info("output image size: %s x %s", output.width, output.height)
|
||||
logger.debug("output image size: %s x %s", output.width, output.height)
|
||||
return output
|
||||
|
|
|
@ -5,7 +5,8 @@ from typing import Optional
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..onnx import OnnxNet
|
||||
from ..models.rrdb import RRDBNet
|
||||
from ..onnx import OnnxRRDBNet
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
|
@ -20,7 +21,6 @@ def load_resrgan(
|
|||
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
||||
):
|
||||
# must be within load function for patches to take effect
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
|
@ -38,7 +38,7 @@ def load_resrgan(
|
|||
|
||||
if params.format == "onnx":
|
||||
# use ONNX acceleration, if available
|
||||
model = OnnxNet(
|
||||
model = OnnxRRDBNet(
|
||||
server,
|
||||
model_file,
|
||||
provider=device.ort_provider(),
|
||||
|
|
|
@ -66,10 +66,12 @@ def upscale_swinir(
|
|||
device = job.get_device()
|
||||
swinir = load_swinir(server, stage, upscale, device)
|
||||
|
||||
# TODO: add support for other sizes
|
||||
tile_size = (64, 64)
|
||||
tile_x = source.width // tile_size[0]
|
||||
tile_y = source.height // tile_size[1]
|
||||
|
||||
# TODO: add support for grayscale (1-channel) images
|
||||
image = np.array(source) / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0)
|
||||
|
|
|
@ -122,6 +122,12 @@ base_models: Models = {
|
|||
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
|
||||
"scale": 4,
|
||||
},
|
||||
{
|
||||
"model": "bsrgan",
|
||||
"name": "upscaling-bsrgan-x2",
|
||||
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGANx2.pth",
|
||||
"scale": 2,
|
||||
},
|
||||
],
|
||||
# download only
|
||||
"sources": [
|
||||
|
|
|
@ -2,9 +2,9 @@ from logging import getLogger
|
|||
from os import path
|
||||
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from torch.onnx import export
|
||||
|
||||
from ...models.rrdb import RRDBNet
|
||||
from ..utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -28,6 +28,7 @@ def convert_upscaling_bsrgan(
|
|||
return
|
||||
|
||||
logger.info("loading and training model")
|
||||
# values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69
|
||||
model = RRDBNet(
|
||||
in_nc=3,
|
||||
out_nc=3,
|
||||
|
|
|
@ -4,6 +4,7 @@ from os import path
|
|||
import torch
|
||||
from torch.onnx import export
|
||||
|
||||
from ...models.rrdb import RRDBNet
|
||||
from ..utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -17,7 +18,6 @@ def convert_upscale_resrgan(
|
|||
model: ModelDict,
|
||||
source: str,
|
||||
):
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
name = model.get("name")
|
||||
|
|
|
@ -28,19 +28,50 @@ def convert_upscaling_swinir(
|
|||
return
|
||||
|
||||
logger.info("loading and training model")
|
||||
img_size = (64, 64) # TODO: does this need to be a fixed value?
|
||||
# values based on https://github.com/JingyunLiang/SwinIR/blob/main/main_test_swinir.py#L128
|
||||
params = {
|
||||
"depths": [6, 6, 6, 6, 6, 6],
|
||||
"embed_dim": 180,
|
||||
"img_range": 1.0,
|
||||
"img_size": (64, 64),
|
||||
"in_chans": 3,
|
||||
"num_heads": [6, 6, 6, 6, 6, 6],
|
||||
"resi_connection": "1conv",
|
||||
"upsampler": "pixelshuffle",
|
||||
"window_size": 8,
|
||||
}
|
||||
|
||||
if "lightweight" in name:
|
||||
logger.debug("using SwinIR lightweight params")
|
||||
params["depths"] = [6, 6, 6, 6]
|
||||
params["embed_dim"] = 60
|
||||
params["num_heads"] = [6, 6, 6, 6]
|
||||
params["upsampler"] = "pixelshuffledirect"
|
||||
elif "real" in name:
|
||||
# TODO: add params for large model
|
||||
logger.debug("using SwinIR real params")
|
||||
params["upsampler"] = "nearest+conv"
|
||||
elif "gray_dn" in name:
|
||||
params["img_size"] = (128, 128)
|
||||
params["in_chans"] = 1
|
||||
params["upsampler"] = ""
|
||||
elif "color_dn" in name:
|
||||
params["img_size"] = (128, 128)
|
||||
params["upsampler"] = ""
|
||||
elif "gray_jpeg" in name:
|
||||
params["img_size"] = (126, 126)
|
||||
params["in_chans"] = 1
|
||||
params["upsampler"] = ""
|
||||
params["window_size"] = 7
|
||||
elif "color_jpeg" in name:
|
||||
params["img_size"] = (126, 126)
|
||||
params["upsampler"] = ""
|
||||
params["window_size"] = 7
|
||||
|
||||
model = SwinIR(
|
||||
depths=[6, 6, 6, 6, 6, 6],
|
||||
embed_dim=180,
|
||||
img_range=1.0,
|
||||
img_size=img_size,
|
||||
in_chans=3,
|
||||
mlp_ratio=2,
|
||||
num_heads=[6, 6, 6, 6, 6, 6],
|
||||
resi_connection="1conv",
|
||||
upscale=scale,
|
||||
upsampler="pixelshuffle",
|
||||
window_size=8,
|
||||
**params,
|
||||
)
|
||||
|
||||
torch_model = torch.load(source, map_location=conversion.map_location)
|
||||
|
|
|
@ -1 +1 @@
|
|||
from .onnx_net import OnnxTensor, OnnxNet
|
||||
from .onnx_net import OnnxTensor, OnnxRRDBNet
|
||||
|
|
|
@ -37,9 +37,9 @@ class OnnxTensor:
|
|||
return np.shape(self.source)
|
||||
|
||||
|
||||
class OnnxNet:
|
||||
class OnnxRRDBNet:
|
||||
"""
|
||||
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
|
||||
Provides the RRDBNet interface using an ONNX session.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
Loading…
Reference in New Issue