Merge branch 'feat/113-onnx-resrgan'
This commit is contained in:
commit
400e579491
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
||||||
last_pipeline_instance = None
|
last_pipeline_instance = None
|
||||||
last_pipeline_params = (None, None)
|
last_pipeline_params = (None, None)
|
||||||
|
|
||||||
x4_v3_tag = "real-esrgan-x4-v3"
|
TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||||
|
|
||||||
|
|
||||||
def load_resrgan(
|
def load_resrgan(
|
||||||
|
@ -37,7 +37,16 @@ def load_resrgan(
|
||||||
if not path.isfile(model_path):
|
if not path.isfile(model_path):
|
||||||
raise Exception("Real ESRGAN model not found at %s" % model_path)
|
raise Exception("Real ESRGAN model not found at %s" % model_path)
|
||||||
|
|
||||||
if x4_v3_tag in model_file:
|
if params.format == "onnx":
|
||||||
|
# use ONNX acceleration, if available
|
||||||
|
model = OnnxNet(
|
||||||
|
server,
|
||||||
|
model_file,
|
||||||
|
provider=device.ort_provider(),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
|
)
|
||||||
|
elif params.format == "pth":
|
||||||
|
if TAG_X4_V3 in model_file:
|
||||||
# the x4-v3 model needs a different network
|
# the x4-v3 model needs a different network
|
||||||
model = SRVGGNetCompact(
|
model = SRVGGNetCompact(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
|
@ -47,15 +56,7 @@ def load_resrgan(
|
||||||
upscale=4,
|
upscale=4,
|
||||||
act_type="prelu",
|
act_type="prelu",
|
||||||
)
|
)
|
||||||
elif params.format == "onnx":
|
else:
|
||||||
# use ONNX acceleration, if available
|
|
||||||
model = OnnxNet(
|
|
||||||
server,
|
|
||||||
model_file,
|
|
||||||
provider=device.ort_provider(),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
elif params.format == "pth":
|
|
||||||
model = RRDBNet(
|
model = RRDBNet(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
num_out_ch=3,
|
num_out_ch=3,
|
||||||
|
@ -68,8 +69,8 @@ def load_resrgan(
|
||||||
raise Exception("unknown platform %s" % params.format)
|
raise Exception("unknown platform %s" % params.format)
|
||||||
|
|
||||||
dni_weight = None
|
dni_weight = None
|
||||||
if params.upscale_model == x4_v3_tag and params.denoise != 1:
|
if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
|
||||||
wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag))
|
wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
|
||||||
model_path = [model_path, wdn_model_path]
|
model_path = [model_path, wdn_model_path]
|
||||||
dni_weight = [params.denoise, 1 - params.denoise]
|
dni_weight = [params.denoise, 1 - params.denoise]
|
||||||
|
|
||||||
|
|
|
@ -3,12 +3,15 @@ from os import path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
from .utils import ConversionContext, ModelDict
|
from .utils import ConversionContext, ModelDict
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_upscale_resrgan(
|
def convert_upscale_resrgan(
|
||||||
|
@ -28,6 +31,18 @@ def convert_upscale_resrgan(
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("loading and training model")
|
logger.info("loading and training model")
|
||||||
|
|
||||||
|
if TAG_X4_V3 in name:
|
||||||
|
# the x4-v3 model needs a different network
|
||||||
|
model = SRVGGNetCompact(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_conv=32,
|
||||||
|
upscale=scale,
|
||||||
|
act_type="prelu",
|
||||||
|
)
|
||||||
|
else:
|
||||||
model = RRDBNet(
|
model = RRDBNet(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
num_out_ch=3,
|
num_out_ch=3,
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
from .onnx_net import OnnxImage, OnnxNet
|
from .onnx_net import OnnxTensor, OnnxNet
|
||||||
|
|
|
@ -8,7 +8,7 @@ from onnxruntime import InferenceSession, SessionOptions
|
||||||
from ..utils import ServerContext
|
from ..utils import ServerContext
|
||||||
|
|
||||||
|
|
||||||
class OnnxImage:
|
class OnnxTensor:
|
||||||
def __init__(self, source) -> None:
|
def __init__(self, source) -> None:
|
||||||
self.source = source
|
self.source = source
|
||||||
self.data = self
|
self.data = self
|
||||||
|
@ -58,7 +58,7 @@ class OnnxNet:
|
||||||
input_name = self.session.get_inputs()[0].name
|
input_name = self.session.get_inputs()[0].name
|
||||||
output_name = self.session.get_outputs()[0].name
|
output_name = self.session.get_outputs()[0].name
|
||||||
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
|
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
|
||||||
return OnnxImage(output)
|
return OnnxTensor(output)
|
||||||
|
|
||||||
def eval(self) -> None:
|
def eval(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue