chore(api): remove some unused non-ONNX code paths
This commit is contained in:
parent
1bfc7bee32
commit
de61e388a0
|
@ -24,6 +24,7 @@ def correct_codeformer(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
# must be within the load function for patch to take effect
|
||||
# TODO: rewrite and remove
|
||||
from codeformer import CodeFormer
|
||||
|
||||
source = stage_source or source
|
||||
|
|
|
@ -20,6 +20,7 @@ def load_gfpgan(
|
|||
device: DeviceParams,
|
||||
):
|
||||
# must be within the load function for patch to take effect
|
||||
# TODO: rewrite and remove
|
||||
from gfpgan import GFPGANer
|
||||
|
||||
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
|
||||
|
|
|
@ -5,7 +5,6 @@ from typing import Optional
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..models.rrdb import RRDBNet
|
||||
from ..onnx import OnnxRRDBNet
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..server import ServerContext
|
||||
|
@ -21,8 +20,8 @@ def load_resrgan(
|
|||
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
||||
):
|
||||
# must be within load function for patches to take effect
|
||||
# TODO: rewrite and remove
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
model_file = "%s.%s" % (params.upscale_model, params.format)
|
||||
model_path = path.join(server.model_path, model_file)
|
||||
|
@ -36,36 +35,13 @@ def load_resrgan(
|
|||
if not path.isfile(model_path):
|
||||
raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path)
|
||||
|
||||
if params.format == "onnx":
|
||||
# use ONNX acceleration, if available
|
||||
# TODO: swap for regular RRDBNet after rewriting wrapper
|
||||
model = OnnxRRDBNet(
|
||||
server,
|
||||
model_file,
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
elif params.format == "pth":
|
||||
if TAG_X4_V3 in model_file:
|
||||
# the x4-v3 model needs a different network
|
||||
model = SRVGGNetCompact(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_conv=32,
|
||||
upscale=4,
|
||||
act_type="prelu",
|
||||
)
|
||||
else:
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=params.scale,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown platform %s" % params.format)
|
||||
|
||||
dni_weight = None
|
||||
if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
|
||||
|
@ -85,7 +61,7 @@ def load_resrgan(
|
|||
tile=tile,
|
||||
tile_pad=params.tile_pad,
|
||||
pre_pad=params.pre_pad,
|
||||
half=params.half,
|
||||
half=False, # TODO: use server optimizations
|
||||
)
|
||||
|
||||
server.cache.set("resrgan", cache_key, upsampler)
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
|||
from torch.onnx import export
|
||||
|
||||
from ...models.rrdb import RRDBNet
|
||||
from ...models.srvgg import SRVGGNetCompact
|
||||
from ..utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -18,8 +19,6 @@ def convert_upscale_resrgan(
|
|||
model: ModelDict,
|
||||
source: str,
|
||||
):
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
scale = model.get("scale")
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class SRVGGNetCompact(nn.Module):
|
||||
"""A compact VGG-style network structure for super-resolution.
|
||||
|
||||
It is a compact network structure, which performs upsampling in the last layer and no convolution is
|
||||
conducted on the HR feature space.
|
||||
|
||||
Args:
|
||||
num_in_ch (int): Channel number of inputs. Default: 3.
|
||||
num_out_ch (int): Channel number of outputs. Default: 3.
|
||||
num_feat (int): Channel number of intermediate features. Default: 64.
|
||||
num_conv (int): Number of convolution layers in the body network. Default: 16.
|
||||
upscale (int): Upsampling factor. Default: 4.
|
||||
act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_conv=16,
|
||||
upscale=4,
|
||||
act_type="prelu",
|
||||
):
|
||||
super(SRVGGNetCompact, self).__init__()
|
||||
self.num_in_ch = num_in_ch
|
||||
self.num_out_ch = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.num_conv = num_conv
|
||||
self.upscale = upscale
|
||||
self.act_type = act_type
|
||||
|
||||
self.body = nn.ModuleList()
|
||||
# the first conv
|
||||
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
||||
# the first activation
|
||||
if act_type == "relu":
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == "prelu":
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == "leakyrelu":
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the body structure
|
||||
for _ in range(num_conv):
|
||||
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
||||
# activation
|
||||
if act_type == "relu":
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == "prelu":
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == "leakyrelu":
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the last conv
|
||||
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
||||
# upsample
|
||||
self.upsampler = nn.PixelShuffle(upscale)
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
for i in range(0, len(self.body)):
|
||||
out = self.body[i](out)
|
||||
|
||||
out = self.upsampler(out)
|
||||
# add the nearest upsampled image, so that the network learns the residual
|
||||
base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
|
||||
out += base
|
||||
return out
|
|
@ -234,8 +234,7 @@ class UpscaleParams:
|
|||
faces=True,
|
||||
face_outscale: int = 1,
|
||||
face_strength: float = 0.5,
|
||||
format: Literal["onnx", "pth"] = "onnx",
|
||||
half=False,
|
||||
format: Literal["onnx", "pth"] = "onnx", # TODO: deprecated, remove
|
||||
outscale: int = 1,
|
||||
scale: int = 4,
|
||||
pre_pad: int = 0,
|
||||
|
@ -251,7 +250,6 @@ class UpscaleParams:
|
|||
self.face_outscale = face_outscale
|
||||
self.face_strength = face_strength
|
||||
self.format = format
|
||||
self.half = half
|
||||
self.outscale = outscale
|
||||
self.pre_pad = pre_pad
|
||||
self.scale = scale
|
||||
|
@ -267,7 +265,6 @@ class UpscaleParams:
|
|||
face_outscale=self.face_outscale,
|
||||
face_strength=self.face_strength,
|
||||
format=self.format,
|
||||
half=self.half,
|
||||
outscale=scale,
|
||||
scale=scale,
|
||||
pre_pad=self.pre_pad,
|
||||
|
@ -294,7 +291,6 @@ class UpscaleParams:
|
|||
"face_outscale": self.face_outscale,
|
||||
"face_strength": self.face_strength,
|
||||
"format": self.format,
|
||||
"half": self.half,
|
||||
"outscale": self.outscale,
|
||||
"pre_pad": self.pre_pad,
|
||||
"scale": self.scale,
|
||||
|
@ -311,7 +307,6 @@ class UpscaleParams:
|
|||
kwargs.get("face_outscale", self.face_outscale),
|
||||
kwargs.get("face_strength", self.face_strength),
|
||||
kwargs.get("format", self.format),
|
||||
kwargs.get("half", self.half),
|
||||
kwargs.get("outscale", self.outscale),
|
||||
kwargs.get("scale", self.scale),
|
||||
kwargs.get("pre_pad", self.pre_pad),
|
||||
|
|
Loading…
Reference in New Issue