feat(api): add params for more SwinIR models
This commit is contained in:
parent
23fb752bb6
commit
2a7621c195
|
@ -30,7 +30,7 @@ from .image import (
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
valid_image,
|
valid_image,
|
||||||
)
|
)
|
||||||
from .onnx import OnnxNet, OnnxTensor
|
from .onnx import OnnxRRDBNet, OnnxTensor
|
||||||
from .params import (
|
from .params import (
|
||||||
Border,
|
Border,
|
||||||
DeviceParams,
|
DeviceParams,
|
||||||
|
|
|
@ -26,10 +26,10 @@ def load_bsrgan(
|
||||||
cache_pipe = server.cache.get("bsrgan", cache_key)
|
cache_pipe = server.cache.get("bsrgan", cache_key)
|
||||||
|
|
||||||
if cache_pipe is not None:
|
if cache_pipe is not None:
|
||||||
logger.info("reusing existing BSRGAN pipeline")
|
logger.debug("reusing existing BSRGAN pipeline")
|
||||||
return cache_pipe
|
return cache_pipe
|
||||||
|
|
||||||
logger.debug("loading BSRGAN model from %s", model_path)
|
logger.info("loading BSRGAN model from %s", model_path)
|
||||||
|
|
||||||
pipe = OnnxModel(
|
pipe = OnnxModel(
|
||||||
server,
|
server,
|
||||||
|
@ -62,7 +62,7 @@ def upscale_bsrgan(
|
||||||
logger.warn("no upscaling model given, skipping")
|
logger.warn("no upscaling model given, skipping")
|
||||||
return source
|
return source
|
||||||
|
|
||||||
logger.info("correcting faces with BSRGAN model: %s", upscale.upscale_model)
|
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
|
||||||
device = job.get_device()
|
device = job.get_device()
|
||||||
bsrgan = load_bsrgan(server, stage, upscale, device)
|
bsrgan = load_bsrgan(server, stage, upscale, device)
|
||||||
|
|
||||||
|
@ -73,13 +73,13 @@ def upscale_bsrgan(
|
||||||
image = np.array(source) / 255.0
|
image = np.array(source) / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
logger.info("BSRGAN input shape: %s", image.shape)
|
logger.trace("BSRGAN input shape: %s", image.shape)
|
||||||
|
|
||||||
scale = upscale.outscale
|
scale = upscale.outscale
|
||||||
dest = np.zeros(
|
dest = np.zeros(
|
||||||
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
|
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
|
||||||
)
|
)
|
||||||
logger.info("BSRGAN output shape: %s", dest.shape)
|
logger.trace("BSRGAN output shape: %s", dest.shape)
|
||||||
|
|
||||||
for x in range(tile_x):
|
for x in range(tile_x):
|
||||||
for y in range(tile_y):
|
for y in range(tile_y):
|
||||||
|
@ -90,7 +90,7 @@ def upscale_bsrgan(
|
||||||
ix2 = xt + tile_size[0]
|
ix2 = xt + tile_size[0]
|
||||||
iy1 = yt
|
iy1 = yt
|
||||||
iy2 = yt + tile_size[1]
|
iy2 = yt + tile_size[1]
|
||||||
logger.info(
|
logger.debug(
|
||||||
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
|
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
|
||||||
ix1,
|
ix1,
|
||||||
ix2,
|
ix2,
|
||||||
|
@ -114,5 +114,5 @@ def upscale_bsrgan(
|
||||||
dest = (dest * 255.0).round().astype(np.uint8)
|
dest = (dest * 255.0).round().astype(np.uint8)
|
||||||
|
|
||||||
output = Image.fromarray(dest, "RGB")
|
output = Image.fromarray(dest, "RGB")
|
||||||
logger.info("output image size: %s x %s", output.width, output.height)
|
logger.debug("output image size: %s x %s", output.width, output.height)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -5,7 +5,8 @@ from typing import Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..onnx import OnnxNet
|
from ..models.rrdb import RRDBNet
|
||||||
|
from ..onnx import OnnxRRDBNet
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
@ -20,7 +21,6 @@ def load_resrgan(
|
||||||
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
||||||
):
|
):
|
||||||
# must be within load function for patches to take effect
|
# must be within load function for patches to take effect
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ def load_resrgan(
|
||||||
|
|
||||||
if params.format == "onnx":
|
if params.format == "onnx":
|
||||||
# use ONNX acceleration, if available
|
# use ONNX acceleration, if available
|
||||||
model = OnnxNet(
|
model = OnnxRRDBNet(
|
||||||
server,
|
server,
|
||||||
model_file,
|
model_file,
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider(),
|
||||||
|
|
|
@ -66,10 +66,12 @@ def upscale_swinir(
|
||||||
device = job.get_device()
|
device = job.get_device()
|
||||||
swinir = load_swinir(server, stage, upscale, device)
|
swinir = load_swinir(server, stage, upscale, device)
|
||||||
|
|
||||||
|
# TODO: add support for other sizes
|
||||||
tile_size = (64, 64)
|
tile_size = (64, 64)
|
||||||
tile_x = source.width // tile_size[0]
|
tile_x = source.width // tile_size[0]
|
||||||
tile_y = source.height // tile_size[1]
|
tile_y = source.height // tile_size[1]
|
||||||
|
|
||||||
|
# TODO: add support for grayscale (1-channel) images
|
||||||
image = np.array(source) / 255.0
|
image = np.array(source) / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
|
|
|
@ -122,6 +122,12 @@ base_models: Models = {
|
||||||
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
|
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
|
||||||
"scale": 4,
|
"scale": 4,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"model": "bsrgan",
|
||||||
|
"name": "upscaling-bsrgan-x2",
|
||||||
|
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGANx2.pth",
|
||||||
|
"scale": 2,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
# download only
|
# download only
|
||||||
"sources": [
|
"sources": [
|
||||||
|
|
|
@ -2,9 +2,9 @@ from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
|
from ...models.rrdb import RRDBNet
|
||||||
from ..utils import ConversionContext, ModelDict
|
from ..utils import ConversionContext, ModelDict
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
|
@ -28,6 +28,7 @@ def convert_upscaling_bsrgan(
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("loading and training model")
|
logger.info("loading and training model")
|
||||||
|
# values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69
|
||||||
model = RRDBNet(
|
model = RRDBNet(
|
||||||
in_nc=3,
|
in_nc=3,
|
||||||
out_nc=3,
|
out_nc=3,
|
||||||
|
|
|
@ -4,6 +4,7 @@ from os import path
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
|
from ...models.rrdb import RRDBNet
|
||||||
from ..utils import ConversionContext, ModelDict
|
from ..utils import ConversionContext, ModelDict
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -17,7 +18,6 @@ def convert_upscale_resrgan(
|
||||||
model: ModelDict,
|
model: ModelDict,
|
||||||
source: str,
|
source: str,
|
||||||
):
|
):
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
|
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
|
|
|
@ -28,19 +28,50 @@ def convert_upscaling_swinir(
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("loading and training model")
|
logger.info("loading and training model")
|
||||||
img_size = (64, 64) # TODO: does this need to be a fixed value?
|
# values based on https://github.com/JingyunLiang/SwinIR/blob/main/main_test_swinir.py#L128
|
||||||
|
params = {
|
||||||
|
"depths": [6, 6, 6, 6, 6, 6],
|
||||||
|
"embed_dim": 180,
|
||||||
|
"img_range": 1.0,
|
||||||
|
"img_size": (64, 64),
|
||||||
|
"in_chans": 3,
|
||||||
|
"num_heads": [6, 6, 6, 6, 6, 6],
|
||||||
|
"resi_connection": "1conv",
|
||||||
|
"upsampler": "pixelshuffle",
|
||||||
|
"window_size": 8,
|
||||||
|
}
|
||||||
|
|
||||||
|
if "lightweight" in name:
|
||||||
|
logger.debug("using SwinIR lightweight params")
|
||||||
|
params["depths"] = [6, 6, 6, 6]
|
||||||
|
params["embed_dim"] = 60
|
||||||
|
params["num_heads"] = [6, 6, 6, 6]
|
||||||
|
params["upsampler"] = "pixelshuffledirect"
|
||||||
|
elif "real" in name:
|
||||||
|
# TODO: add params for large model
|
||||||
|
logger.debug("using SwinIR real params")
|
||||||
|
params["upsampler"] = "nearest+conv"
|
||||||
|
elif "gray_dn" in name:
|
||||||
|
params["img_size"] = (128, 128)
|
||||||
|
params["in_chans"] = 1
|
||||||
|
params["upsampler"] = ""
|
||||||
|
elif "color_dn" in name:
|
||||||
|
params["img_size"] = (128, 128)
|
||||||
|
params["upsampler"] = ""
|
||||||
|
elif "gray_jpeg" in name:
|
||||||
|
params["img_size"] = (126, 126)
|
||||||
|
params["in_chans"] = 1
|
||||||
|
params["upsampler"] = ""
|
||||||
|
params["window_size"] = 7
|
||||||
|
elif "color_jpeg" in name:
|
||||||
|
params["img_size"] = (126, 126)
|
||||||
|
params["upsampler"] = ""
|
||||||
|
params["window_size"] = 7
|
||||||
|
|
||||||
model = SwinIR(
|
model = SwinIR(
|
||||||
depths=[6, 6, 6, 6, 6, 6],
|
|
||||||
embed_dim=180,
|
|
||||||
img_range=1.0,
|
|
||||||
img_size=img_size,
|
|
||||||
in_chans=3,
|
|
||||||
mlp_ratio=2,
|
mlp_ratio=2,
|
||||||
num_heads=[6, 6, 6, 6, 6, 6],
|
|
||||||
resi_connection="1conv",
|
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
upsampler="pixelshuffle",
|
**params,
|
||||||
window_size=8,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_model = torch.load(source, map_location=conversion.map_location)
|
torch_model = torch.load(source, map_location=conversion.map_location)
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
from .onnx_net import OnnxTensor, OnnxNet
|
from .onnx_net import OnnxTensor, OnnxRRDBNet
|
||||||
|
|
|
@ -37,9 +37,9 @@ class OnnxTensor:
|
||||||
return np.shape(self.source)
|
return np.shape(self.source)
|
||||||
|
|
||||||
|
|
||||||
class OnnxNet:
|
class OnnxRRDBNet:
|
||||||
"""
|
"""
|
||||||
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
|
Provides the RRDBNet interface using an ONNX session.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
Loading…
Reference in New Issue