1
0
Fork 0

chore(api): remove some unused non-ONNX code paths

This commit is contained in:
Sean Sube 2023-04-11 08:26:21 -05:00
parent 1bfc7bee32
commit de61e388a0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 88 additions and 41 deletions

View File

@ -24,6 +24,7 @@ def correct_codeformer(
**kwargs,
) -> Image.Image:
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from codeformer import CodeFormer
source = stage_source or source

View File

@ -20,6 +20,7 @@ def load_gfpgan(
device: DeviceParams,
):
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from gfpgan import GFPGANer
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))

View File

@ -5,7 +5,6 @@ from typing import Optional
import numpy as np
from PIL import Image
from ..models.rrdb import RRDBNet
from ..onnx import OnnxRRDBNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
@ -21,8 +20,8 @@ def load_resrgan(
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
# must be within load function for patches to take effect
# TODO: rewrite and remove
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
model_file = "%s.%s" % (params.upscale_model, params.format)
model_path = path.join(server.model_path, model_file)
@ -36,36 +35,13 @@ def load_resrgan(
if not path.isfile(model_path):
raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path)
if params.format == "onnx":
# use ONNX acceleration, if available
model = OnnxRRDBNet(
server,
model_file,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
elif params.format == "pth":
if TAG_X4_V3 in model_file:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
else:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=params.scale,
)
else:
raise ValueError("unknown platform %s" % params.format)
# TODO: swap for regular RRDBNet after rewriting wrapper
model = OnnxRRDBNet(
server,
model_file,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
dni_weight = None
if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
@ -85,7 +61,7 @@ def load_resrgan(
tile=tile,
tile_pad=params.tile_pad,
pre_pad=params.pre_pad,
half=params.half,
half=False, # TODO: use server optimizations
)
server.cache.set("resrgan", cache_key, upsampler)

View File

@ -5,6 +5,7 @@ import torch
from torch.onnx import export
from ...models.rrdb import RRDBNet
from ...models.srvgg import SRVGGNetCompact
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@ -18,8 +19,6 @@ def convert_upscale_resrgan(
model: ModelDict,
source: str,
):
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
name = model.get("name")
source = source or model.get("source")
scale = model.get("scale")

View File

@ -0,0 +1,75 @@
from torch import nn as nn
from torch.nn import functional as F
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network structure, which performs upsampling in the last
    layer and no convolution is conducted on the HR feature space.

    The layers are stored in ``self.body`` in this order: one input
    convolution, one activation, ``num_conv`` (convolution, activation) pairs,
    and a final convolution producing ``num_out_ch * upscale * upscale``
    channels, which ``self.upsampler`` (a pixel shuffle) rearranges into the
    upscaled output.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.

    Raises:
        ValueError: If ``act_type`` is not one of the supported options.
    """

    def __init__(
        self,
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_conv=16,
        upscale=4,
        act_type="prelu",
    ):
        super().__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        # The attribute name and layer order determine the state-dict keys
        # (body.0, body.1, ...), so both are kept stable.
        self.body = nn.ModuleList()

        # the first conv and its activation
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        self.body.append(self._make_activation(act_type, num_feat))

        # the body structure: num_conv (conv, activation) pairs
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            self.body.append(self._make_activation(act_type, num_feat))

        # the last conv expands channels so the pixel shuffle can upsample
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))

        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    @staticmethod
    def _make_activation(act_type, num_feat):
        """Build a fresh activation module for the given activation type.

        A new module is created on every call so each layer owns its own
        activation (and, for PReLU, its own learnable parameters).

        Raises:
            ValueError: If ``act_type`` is not 'relu', 'prelu' or 'leakyrelu'.
        """
        if act_type == "relu":
            return nn.ReLU(inplace=True)
        if act_type == "prelu":
            return nn.PReLU(num_parameters=num_feat)
        if act_type == "leakyrelu":
            return nn.LeakyReLU(negative_slope=0.1, inplace=True)
        raise ValueError(
            "unsupported act_type %r, expected one of: 'relu', 'prelu', 'leakyrelu'"
            % (act_type,)
        )

    def forward(self, x):
        """Run the network on ``x`` and return the upscaled result.

        The body output is pixel-shuffled and then added to a nearest-neighbor
        upsampled copy of the input.
        """
        out = x
        for layer in self.body:
            out = layer(out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
        out += base
        return out

View File

@ -234,8 +234,7 @@ class UpscaleParams:
faces=True,
face_outscale: int = 1,
face_strength: float = 0.5,
format: Literal["onnx", "pth"] = "onnx",
half=False,
format: Literal["onnx", "pth"] = "onnx", # TODO: deprecated, remove
outscale: int = 1,
scale: int = 4,
pre_pad: int = 0,
@ -251,7 +250,6 @@ class UpscaleParams:
self.face_outscale = face_outscale
self.face_strength = face_strength
self.format = format
self.half = half
self.outscale = outscale
self.pre_pad = pre_pad
self.scale = scale
@ -267,7 +265,6 @@ class UpscaleParams:
face_outscale=self.face_outscale,
face_strength=self.face_strength,
format=self.format,
half=self.half,
outscale=scale,
scale=scale,
pre_pad=self.pre_pad,
@ -294,7 +291,6 @@ class UpscaleParams:
"face_outscale": self.face_outscale,
"face_strength": self.face_strength,
"format": self.format,
"half": self.half,
"outscale": self.outscale,
"pre_pad": self.pre_pad,
"scale": self.scale,
@ -311,7 +307,6 @@ class UpscaleParams:
kwargs.get("face_outscale", self.face_outscale),
kwargs.get("face_strength", self.face_strength),
kwargs.get("format", self.format),
kwargs.get("half", self.half),
kwargs.get("outscale", self.outscale),
kwargs.get("scale", self.scale),
kwargs.get("pre_pad", self.pre_pad),