1
0
Fork 0

Merge branch 'feat/113-onnx-resrgan'

This commit is contained in:
Sean Sube 2023-02-18 09:23:36 -06:00
commit 400e579491
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 49 additions and 33 deletions

View File

@ -14,7 +14,7 @@ logger = getLogger(__name__)
last_pipeline_instance = None last_pipeline_instance = None
last_pipeline_params = (None, None) last_pipeline_params = (None, None)
x4_v3_tag = "real-esrgan-x4-v3" TAG_X4_V3 = "real-esrgan-x4-v3"
def load_resrgan( def load_resrgan(
@ -37,17 +37,7 @@ def load_resrgan(
if not path.isfile(model_path): if not path.isfile(model_path):
raise Exception("Real ESRGAN model not found at %s" % model_path) raise Exception("Real ESRGAN model not found at %s" % model_path)
if x4_v3_tag in model_file: if params.format == "onnx":
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
elif params.format == "onnx":
# use ONNX acceleration, if available # use ONNX acceleration, if available
model = OnnxNet( model = OnnxNet(
server, server,
@ -56,20 +46,31 @@ def load_resrgan(
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )
elif params.format == "pth": elif params.format == "pth":
model = RRDBNet( if TAG_X4_V3 in model_file:
num_in_ch=3, # the x4-v3 model needs a different network
num_out_ch=3, model = SRVGGNetCompact(
num_feat=64, num_in_ch=3,
num_block=23, num_out_ch=3,
num_grow_ch=32, num_feat=64,
scale=params.scale, num_conv=32,
) upscale=4,
act_type="prelu",
)
else:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=params.scale,
)
else: else:
raise Exception("unknown platform %s" % params.format) raise Exception("unknown platform %s" % params.format)
dni_weight = None dni_weight = None
if params.upscale_model == x4_v3_tag and params.denoise != 1: if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag)) wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
model_path = [model_path, wdn_model_path] model_path = [model_path, wdn_model_path]
dni_weight = [params.denoise, 1 - params.denoise] dni_weight = [params.denoise, 1 - params.denoise]

View File

@ -3,12 +3,15 @@ from os import path
import torch import torch
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from torch.onnx import export from torch.onnx import export
from .utils import ConversionContext, ModelDict from .utils import ConversionContext, ModelDict
logger = getLogger(__name__) logger = getLogger(__name__)
TAG_X4_V3 = "real-esrgan-x4-v3"
@torch.no_grad() @torch.no_grad()
def convert_upscale_resrgan( def convert_upscale_resrgan(
@ -28,14 +31,26 @@ def convert_upscale_resrgan(
return return
logger.info("loading and training model") logger.info("loading and training model")
model = RRDBNet(
num_in_ch=3, if TAG_X4_V3 in name:
num_out_ch=3, # the x4-v3 model needs a different network
num_feat=64, model = SRVGGNetCompact(
num_block=23, num_in_ch=3,
num_grow_ch=32, num_out_ch=3,
scale=scale, num_feat=64,
) num_conv=32,
upscale=scale,
act_type="prelu",
)
else:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)
torch_model = torch.load(source, map_location=ctx.map_location) torch_model = torch.load(source, map_location=ctx.map_location)
if "params_ema" in torch_model: if "params_ema" in torch_model:

View File

@ -1 +1 @@
from .onnx_net import OnnxImage, OnnxNet from .onnx_net import OnnxTensor, OnnxNet

View File

@ -8,7 +8,7 @@ from onnxruntime import InferenceSession, SessionOptions
from ..utils import ServerContext from ..utils import ServerContext
class OnnxImage: class OnnxTensor:
def __init__(self, source) -> None: def __init__(self, source) -> None:
self.source = source self.source = source
self.data = self self.data = self
@ -58,7 +58,7 @@ class OnnxNet:
input_name = self.session.get_inputs()[0].name input_name = self.session.get_inputs()[0].name
output_name = self.session.get_outputs()[0].name output_name = self.session.get_outputs()[0].name
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0] output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
return OnnxImage(output) return OnnxTensor(output)
def eval(self) -> None: def eval(self) -> None:
pass pass