
feat(api): add initial support for BSRGAN and SwinIR upscaling (#153, #154)

Sean Sube 2023-04-10 17:49:56 -05:00
parent 464bfd01b8
commit 62aa7e8473
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
18 changed files with 1697 additions and 11 deletions


@@ -10,9 +10,11 @@ from .reduce_crop import reduce_crop
from .reduce_thumbnail import reduce_thumbnail
from .source_noise import source_noise
from .source_txt2img import source_txt2img
from .upscale_bsrgan import upscale_bsrgan
from .upscale_outpaint import upscale_outpaint
from .upscale_resrgan import upscale_resrgan
from .upscale_stable_diffusion import upscale_stable_diffusion
from .upscale_swinir import upscale_swinir
CHAIN_STAGES = {
"blend-img2img": blend_img2img,
@@ -26,7 +28,9 @@ CHAIN_STAGES = {
"reduce-thumbnail": reduce_thumbnail,
"source-noise": source_noise,
"source-txt2img": source_txt2img,
"upscale-bsrgan": upscale_bsrgan,
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,
"upscale-swinir": upscale_swinir,
}
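
For orientation, a minimal sketch of how a chain definition could resolve these stage names to their callables. The step field names used here ("type", "params") are hypothetical, not the server's actual chain format:

# minimal sketch: look up chain step names in CHAIN_STAGES; the step field
# names ("type", "params") are hypothetical, for illustration only
from typing import Any, Callable, Dict, List, Tuple

def resolve_stages(steps: List[Dict[str, Any]]) -> List[Tuple[Callable, Dict[str, Any]]]:
    resolved = []
    for step in steps:
        name = step["type"]  # e.g. "upscale-bsrgan" or "upscale-swinir"
        if name not in CHAIN_STAGES:
            raise ValueError("unknown chain stage: %s" % name)
        resolved.append((CHAIN_STAGES[name], step.get("params", {})))
    return resolved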


@@ -0,0 +1,118 @@
from logging import getLogger
from os import path
from typing import Optional
import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
logger = getLogger(__name__)
def load_bsrgan(
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
):
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("bsrgan", cache_key)
if cache_pipe is not None:
logger.info("reusing existing BSRGAN pipeline")
return cache_pipe
logger.debug("loading BSRGAN model from %s", model_path)
pipe = OnnxModel(
server,
model_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
server.cache.set("bsrgan", cache_key, pipe)
run_gc([device])
return pipe
def upscale_bsrgan(
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.upscale_model is None:
logger.warning("no upscaling model given, skipping")
return source
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
device = job.get_device()
bsrgan = load_bsrgan(server, stage, upscale, device)
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.info("BSRGAN input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
)
logger.info("BSRGAN output shape: %s", dest.shape)
for x in range(tile_x):
for y in range(tile_y):
xt = x * tile_size[0]
yt = y * tile_size[1]
ix1 = xt
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
logger.info(
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
iy1,
iy2,
ix1 * scale,
ix2 * scale,
iy1 * scale,
iy2 * scale,
)
dest[
:,
:,
ix1 * scale : ix2 * scale,
iy1 * scale : iy2 * scale,
] = bsrgan(image[:, :, ix1:ix2, iy1:iy2])
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height)
return output
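
For reference, a self-contained sketch of the tensor layout round trip this stage performs around the ONNX call: normalize to [0, 1], reorder RGB to BGR, move HWC to NCHW, then invert those steps on the output. The model call itself is omitted so the example runs with numpy and PIL alone.

# hedged sketch of the pre- and post-processing used by the upscaling stages;
# the ONNX model call is left out to keep this runnable on its own
import numpy as np
from PIL import Image

def to_nchw_bgr(source: Image.Image) -> np.ndarray:
    image = np.array(source) / 255.0                        # HWC, RGB, [0, 1]
    image = image[:, :, [2, 1, 0]]                          # RGB -> BGR
    image = image.astype(np.float32).transpose((2, 0, 1))   # HWC -> CHW
    return np.expand_dims(image, axis=0)                    # CHW -> NCHW

def from_nchw_bgr(batch: np.ndarray) -> Image.Image:
    image = np.clip(np.squeeze(batch, axis=0), 0, 1)        # NCHW -> CHW
    image = image[[2, 1, 0], :, :].transpose((1, 2, 0))     # BGR -> RGB, CHW -> HWC
    return Image.fromarray((image * 255.0).round().astype(np.uint8), "RGB")

tile = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8), "RGB")
assert from_nchw_bgr(to_nchw_bgr(tile)).size == tile.size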


@@ -0,0 +1,118 @@
from logging import getLogger
from os import path
from typing import Optional
import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
logger = getLogger(__name__)
def load_swinir(
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
):
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("swinir", cache_key)
if cache_pipe is not None:
logger.info("reusing existing SwinIR pipeline")
return cache_pipe
logger.debug("loading SwinIR model from %s", model_path)
pipe = OnnxModel(
server,
model_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
server.cache.set("swinir", cache_key, pipe)
run_gc([device])
return pipe
def upscale_swinir(
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.upscale_model is None:
logger.warning("no upscaling model given, skipping")
return source
logger.info("upscaling with SwinIR model: %s", upscale.upscale_model)
device = job.get_device()
swinir = load_swinir(server, stage, upscale, device)
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.info("SwinIR input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
)
logger.info("SwinIR output shape: %s", dest.shape)
for x in range(tile_x):
for y in range(tile_y):
xt = x * tile_size[0]
yt = y * tile_size[1]
ix1 = xt
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
logger.info(
"running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
iy1,
iy2,
ix1 * scale,
ix2 * scale,
iy1 * scale,
iy2 * scale,
)
dest[
:,
:,
ix1 * scale : ix2 * scale,
iy1 * scale : iy2 * scale,
] = swinir(image[:, :, ix1:ix2, iy1:iy2])
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height)
return output
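
Both stages walk the source in fixed 64 pixel steps and copy each upscaled tile into the pre-allocated output array. The sketch below spells out that index arithmetic; note that a remainder narrower than one tile is skipped by the integer division.

# hedged sketch of the tile index arithmetic used above; the source size and
# outscale are example values, not taken from a specific request
tile_size = (64, 64)
outscale = 4
width, height = 512, 384

for x in range(width // tile_size[0]):
    for y in range(height // tile_size[1]):
        ix1, ix2 = x * tile_size[0], (x + 1) * tile_size[0]
        iy1, iy2 = y * tile_size[1], (y + 1) * tile_size[1]
        # the source window [ix1:ix2, iy1:iy2] is written to the output window
        # [ix1 * outscale : ix2 * outscale, iy1 * outscale : iy2 * outscale]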


@@ -13,12 +13,14 @@ from transformers import CLIPTokenizer
from yaml import safe_load
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from .correction_gfpgan import convert_correction_gfpgan
from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.lora import blend_loras
from .diffusion.original import convert_diffusion_original
from .diffusion.textual_inversion import blend_textual_inversions
from .upscale_resrgan import convert_upscale_resrgan
from .upscaling.bsrgan import convert_upscaling_bsrgan
from .upscaling.resrgan import convert_upscale_resrgan
from .upscaling.swinir import convert_upscaling_swinir
from .utils import (
ConversionContext,
download_progress,
@@ -108,6 +110,18 @@ base_models: Models = {
"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
4,
),
{
"model": "swinir",
"name": "upscaling-swinir-classical-x4",
"source": "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth",
"scale": 4,
},
{
"model": "bsrgan",
"name": "upscaling-bsrgan-x4",
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
"scale": 4,
},
],
# download only
"sources": [
@@ -415,7 +429,17 @@ def convert_models(conversion: ConversionContext, args, models: Models):
source = fetch_model(
conversion, name, model["source"], format=model_format
)
convert_upscale_resrgan(conversion, model, source)
model_type = model.get("model", "resrgan")
if model_type == "bsrgan":
convert_upscaling_bsrgan(conversion, model, source)
elif model_type == "resrgan":
convert_upscale_resrgan(conversion, model, source)
elif model_type == "swinir":
convert_upscaling_swinir(conversion, model, source)
else:
logger.error(
"unknown upscaling model type %s for %s", model_type, name
)
except Exception:
logger.exception(
"error converting upscaling model %s",
@@ -435,7 +459,13 @@ def convert_models(conversion: ConversionContext, args, models: Models):
source = fetch_model(
conversion, name, model["source"], format=model_format
)
convert_correction_gfpgan(conversion, model, source)
model_type = model.get("model", "gfpgan")
if model_type == "gfpgan":
convert_correction_gfpgan(conversion, model, source)
else:
logger.error(
"unknown correction model type %s for %s", model_type, name
)
except Exception:
logger.exception(
"error converting correction model %s",

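To make the model-type dispatch above concrete, a hedged example of upscaling entries written as plain dicts, mirroring the base_models additions earlier in this commit, and the converter each would be routed to. An entry without a "model" key falls back to the Real ESRGAN converter.

# hedged example of upscaling entries as routed by the dispatch above; names
# and source URLs are taken from elsewhere in this commit
upscaling_models = [
    {
        "model": "bsrgan",  # -> convert_upscaling_bsrgan
        "name": "upscaling-bsrgan-x4",
        "source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
        "scale": 4,
    },
    {
        "model": "swinir",  # -> convert_upscaling_swinir
        "name": "upscaling-swinir-classical-x4",
        "source": "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth",
        "scale": 4,
    },
    {
        # no "model" key: model.get("model", "resrgan") selects convert_upscale_resrgan
        "name": "upscaling-real-esrgan-x4-v3",
        "source": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
        "scale": 4,
    },
]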

@@ -5,7 +5,7 @@ import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from torch.onnx import export
from .utils import ConversionContext, ModelDict
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@@ -18,7 +18,7 @@ def convert_correction_gfpgan(
):
name = model.get("name")
source = source or model.get("source")
scale = model.get("scale")
scale = model.get("scale", 1)
dest = path.join(conversion.model_path, name + ".onnx")
logger.info("converting GFPGAN model: %s -> %s", name, dest)


@@ -277,6 +277,7 @@ def blend_loras(
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
logger.trace("fixed node names: %s", fixed_node_names)
unmatched_keys = []
for base_key, weights in blended.items():
conv_key = base_key + "_Conv"
gemm_key = base_key + "_Gemm"
@@ -377,7 +378,7 @@ def blend_loras(
del base_model.graph.initializer[matmul_idx]
base_model.graph.initializer.insert(matmul_idx, updated_node)
else:
logger.warning("could not find any nodes for %s", base_key)
unmatched_keys.append(base_key)
logger.debug(
"node counts: %s -> %s, %s -> %s",
@@ -387,6 +388,9 @@ def blend_loras(
len(base_model.graph.node),
)
if len(unmatched_keys) > 0:
logger.warning("could not find nodes for some keys: %s", unmatched_keys)
return base_model


@@ -0,0 +1,70 @@
from logging import getLogger
from os import path
import torch
from torch.onnx import export
from ...models.rrdb import RRDBNet
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@torch.no_grad()
def convert_upscaling_bsrgan(
conversion: ConversionContext,
model: ModelDict,
source: str,
):
name = model.get("name")
source = source or model.get("source")
scale = model.get("scale", 1)
dest = path.join(conversion.model_path, name + ".onnx")
logger.info("converting BSRGAN model: %s -> %s", name, dest)
if path.isfile(dest):
logger.info("ONNX model already exists, skipping")
return
logger.info("loading model weights")
model = RRDBNet(
in_nc=3,
out_nc=3,
nf=64,
nb=23,
gc=32,
sf=scale,
)
torch_model = torch.load(source, map_location=conversion.map_location)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"], strict=False)
elif "params" in torch_model:
model.load_state_dict(torch_model["params"], strict=False)
else:
model.load_state_dict(torch_model, strict=False)
model.to(conversion.training_device).train(False)
model.eval()
rng = torch.rand(1, 3, 64, 64, device=conversion.training_device)
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {
"input": {2: "h", 3: "w"},
"output": {2: "h", 3: "w"},
}
logger.info("exporting ONNX model to %s", dest)
export(
model,
rng,
dest,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=conversion.opset,
export_params=True,
)
logger.info("BSRGAN exported to ONNX successfully")
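
A hedged way to sanity-check the exported file with onnxruntime: feed a random 64x64 tile and confirm the spatial dimensions come back multiplied by the model's scale. The path below is an example of the converter's "<name>.onnx" output, not a fixed location.

# hedged sketch: smoke-test the exported ONNX model with onnxruntime
import numpy as np
from onnxruntime import InferenceSession

session = InferenceSession(
    "models/upscaling-bsrgan-x4.onnx",  # example path, adjust to your model directory
    providers=["CPUExecutionProvider"],
)
tile = np.random.rand(1, 3, 64, 64).astype(np.float32)
result = session.run(None, {session.get_inputs()[0].name: tile})[0]
print(result.shape)  # expected (1, 3, 256, 256) for a x4 model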


@@ -4,7 +4,7 @@ from os import path
import torch
from torch.onnx import export
from .utils import ConversionContext, ModelDict
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)


@@ -0,0 +1,74 @@
from logging import getLogger
from os import path
import torch
from torch.onnx import export
from ...models.swinir import SwinIR
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@torch.no_grad()
def convert_upscaling_swinir(
conversion: ConversionContext,
model: ModelDict,
source: str,
):
name = model.get("name")
source = source or model.get("source")
scale = model.get("scale", 1)
dest = path.join(conversion.model_path, name + ".onnx")
logger.info("converting SwinIR model: %s -> %s", name, dest)
if path.isfile(dest):
logger.info("ONNX model already exists, skipping")
return
logger.info("loading model weights")
img_size = (64, 64) # TODO: does this need to be a fixed value?
model = SwinIR(
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
img_range=1.0,
img_size=img_size,
in_chans=3,
mlp_ratio=2,
num_heads=[6, 6, 6, 6, 6, 6],
resi_connection="1conv",
upscale=scale,
upsampler="pixelshuffle",
window_size=8,
)
torch_model = torch.load(source, map_location=conversion.map_location)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"], strict=False)
else:
model.load_state_dict(torch_model["params"], strict=False)
model.to(conversion.training_device).train(False)
model.eval()
rng = torch.rand(1, 3, 64, 64, device=conversion.training_device)
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {
"input": {2: "h", 3: "w"},
"output": {2: "h", 3: "w"},
}
logger.info("exporting ONNX model to %s", dest)
export(
model,
rng,
dest,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=conversion.opset,
export_params=True,
)
logger.info("SwinIR exported to ONNX successfully")
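
The checkpoint handling above tries "params_ema", then "params", then (in the BSRGAN converter) a bare state dict. A hedged sketch for inspecting a downloaded checkpoint up front, using the SwinIR filename from the source URL above:

# hedged sketch: check which state-dict key a checkpoint exposes, mirroring
# the fallbacks used by the converters above
import torch

checkpoint = torch.load(
    "001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth", map_location="cpu"
)
if isinstance(checkpoint, dict) and "params_ema" in checkpoint:
    state = checkpoint["params_ema"]
elif isinstance(checkpoint, dict) and "params" in checkpoint:
    state = checkpoint["params"]
else:
    state = checkpoint
print("state dict with %s tensors" % len(state))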


@@ -7,8 +7,10 @@ from ..chain import (
ChainPipeline,
correct_codeformer,
correct_gfpgan,
upscale_bsrgan,
upscale_resrgan,
upscale_stable_diffusion,
upscale_swinir,
)
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
from ..server import ServerContext
@@ -41,15 +43,28 @@ def run_upscale_correction(
upscale_stage = None
if upscale.scale > 1:
if "esrgan" in upscale.upscale_model:
if "bsrgan" in upscale.upscale_model:
bsrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (upscale_bsrgan, bsrgan_params, None)
elif "esrgan" in upscale.upscale_model:
esrgan_params = StageParams(
tile_size=stage.tile_size, outscale=upscale.outscale
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (upscale_resrgan, esrgan_params, None)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (upscale_stable_diffusion, sd_params, None)
elif "swinir" in upscale.upscale_model:
swinir_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (upscale_swinir, swinir_params, None)
else:
logger.warning("unknown upscaling model: %s", upscale.upscale_model)


@@ -0,0 +1,28 @@
from logging import getLogger
from os import path
from typing import Any, Optional
from ..server import ServerContext
from ..torch_before_ort import InferenceSession, SessionOptions
logger = getLogger(__name__)
class OnnxModel:
def __init__(
self,
server: ServerContext,
model: str,
provider: str = "DmlExecutionProvider",
sess_options: Optional[SessionOptions] = None,
) -> None:
model_path = path.join(server.model_path, model)
self.session = InferenceSession(
model_path, sess_options=sess_options, providers=[provider]
)
def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name
output_name = self.session.get_outputs()[0].name
output = self.session.run([output_name], {input_name: image})[0]
return output

api/onnx_web/models/rrdb.py (new file, 112 lines)

@@ -0,0 +1,112 @@
# from https://github.com/cszn/BSRGAN/blob/main/models/network_rrdbnet.py
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode="fan_in")
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode="fan_in")
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
initialize_weights(
[self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1
)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
"""Residual in Residual Dense Block"""
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.sf = sf
print([in_nc, out_nc, nf, nb, gc, sf])
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
# upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.sf == 4:
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(
self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
)
if self.sf == 4:
fea = self.lrelu(
self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
)
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
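
A hedged smoke test for the vendored RRDBNet: with the default sf=4 configuration and untrained weights, a random 64x64 input should come out 256x256.

# hedged smoke test for the RRDBNet defined above: untrained weights, random
# input, checks only that the x4 output shape is produced
import torch

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4).eval()
with torch.no_grad():
    out = net(torch.rand(1, 3, 64, 64))
print(out.shape)  # expected torch.Size([1, 3, 256, 256])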

File diff suppressed because it is too large.


@@ -20,7 +20,7 @@ from ..image import ( # mask filters; noise sources
noise_source_normal,
noise_source_uniform,
)
from ..models import NetworkModel
from ..models.meta import NetworkModel
from ..params import DeviceParams
from ..torch_before_ort import get_available_providers
from ..utils import merge


@@ -17,6 +17,7 @@ codeformer-perceptor
facexlib
gfpgan
realesrgan
timm
### Server packages ###
boto3


@@ -62,6 +62,14 @@ $defs:
correction_model:
allOf:
- $ref: "#/$defs/base_model"
- type: object
properties:
model:
type: string
enum: [
codeformer,
gfpgan
]
diffusion_model:
allOf:
@@ -87,6 +95,13 @@ $defs:
- type: object
required: [scale]
properties:
model:
type: string
enum: [
bsrgan,
resrgan,
swinir
]
scale:
type: number
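
Expressed as plain dicts (the shape the converter reads after parsing the extras file), entries accepted by the updated schema look roughly like the following; the correction name and source URL are placeholders, only the new "model" enums and the required "scale" field are taken from this commit.

# hedged example entries for the updated schema; the correction name and
# source URL are placeholders, not values from this commit
upscaling_entry = {
    "name": "upscaling-bsrgan-x4",
    "source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
    "model": "bsrgan",  # one of: bsrgan, resrgan, swinir
    "scale": 4,         # required for upscaling models
}
correction_entry = {
    "name": "correction-gfpgan-example",              # placeholder name
    "source": "https://example.com/GFPGANv1.3.pth",   # placeholder URL
    "model": "gfpgan",  # one of: codeformer, gfpgan
}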


@@ -81,10 +81,18 @@ export const I18N_STRINGS_EN = {
'inversion-minecraft': 'Minecraft Concept',
'inversion-ugly-sonic': 'Ugly Sonic',
// upscaling
'upscaling-bsrgan-x2': 'BSRGAN x2',
'upscaling-bsrgan-x4': 'BSRGAN x4',
'upscaling-real-esrgan-x2-plus': 'Real ESRGAN x2 Plus',
'upscaling-real-esrgan-x4-plus': 'Real ESRGAN x4 Plus',
'upscaling-real-esrgan-x4-v3': 'Real ESRGAN x4 v3',
'upscaling-stable-diffusion-x4': 'Stable Diffusion x4',
'upscaling-swinir-classical-x2': 'SwinIR Classical x2',
'upscaling-swinir-classical-x3': 'SwinIR Classical x3',
'upscaling-swinir-classical-x4': 'SwinIR Classical x4',
'upscaling-swinir-classical-x8': 'SwinIR Classical x8',
'upscaling-swinir-real-x2': 'SwinIR Real x2',
'upscaling-swinir-real-x4': 'SwinIR Real x4',
// extras
'diffusion-stablydiffused-aesthetic-v2-6': 'Aesthetic Mix v2.6',
'diffusion-anything': 'Anything',