1
0
Fork 0

feat(api): use ONNX for Real ESRGAN v3 model

This commit is contained in:
Sean Sube 2023-02-17 23:25:42 -06:00
parent 3dde3b9237
commit 2c9d96d2ee
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 25 additions and 22 deletions

View File

@ -37,16 +37,6 @@ def load_resrgan(
if not path.isfile(model_path):
raise Exception("Real ESRGAN model not found at %s" % model_path)
if x4_v3_tag in model_file:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
elif params.format == "onnx":
# use ONNX acceleration, if available
model = OnnxNet(
@ -56,14 +46,25 @@ def load_resrgan(
sess_options=device.sess_options(),
)
elif params.format == "pth":
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=params.scale,
)
if x4_v3_tag in model_file:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
else:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=params.scale,
)
else:
raise Exception("unknown platform %s" % params.format)

View File

@ -1 +1 @@
from .onnx_net import OnnxImage, OnnxNet
from .onnx_net import OnnxTensor, OnnxNet

View File

@ -8,7 +8,7 @@ from onnxruntime import InferenceSession, SessionOptions
from ..utils import ServerContext
class OnnxImage:
class OnnxTensor:
def __init__(self, source) -> None:
self.source = source
self.data = self
@ -57,8 +57,10 @@ class OnnxNet:
def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name
output_name = self.session.get_outputs()[0].name
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
return OnnxImage(output)
output = self.session.run([output_name], {
input_name: image.cpu().numpy()
})[0]
return OnnxTensor(output)
def eval(self) -> None:
pass