1
0
Fork 0

feat(api): use ONNX for Real ESRGAN v3 model

This commit is contained in:
Sean Sube 2023-02-17 23:25:42 -06:00
parent 3dde3b9237
commit 2c9d96d2ee
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 25 additions and 22 deletions

View File

@ -37,6 +37,15 @@ def load_resrgan(
if not path.isfile(model_path): if not path.isfile(model_path):
raise Exception("Real ESRGAN model not found at %s" % model_path) raise Exception("Real ESRGAN model not found at %s" % model_path)
elif params.format == "onnx":
# use ONNX acceleration, if available
model = OnnxNet(
server,
model_file,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
elif params.format == "pth":
if x4_v3_tag in model_file: if x4_v3_tag in model_file:
# the x4-v3 model needs a different network # the x4-v3 model needs a different network
model = SRVGGNetCompact( model = SRVGGNetCompact(
@ -47,15 +56,7 @@ def load_resrgan(
upscale=4, upscale=4,
act_type="prelu", act_type="prelu",
) )
elif params.format == "onnx": else:
# use ONNX acceleration, if available
model = OnnxNet(
server,
model_file,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
elif params.format == "pth":
model = RRDBNet( model = RRDBNet(
num_in_ch=3, num_in_ch=3,
num_out_ch=3, num_out_ch=3,

View File

@ -1 +1 @@
from .onnx_net import OnnxImage, OnnxNet from .onnx_net import OnnxTensor, OnnxNet

View File

@ -8,7 +8,7 @@ from onnxruntime import InferenceSession, SessionOptions
from ..utils import ServerContext from ..utils import ServerContext
class OnnxImage: class OnnxTensor:
def __init__(self, source) -> None: def __init__(self, source) -> None:
self.source = source self.source = source
self.data = self self.data = self
@ -57,8 +57,10 @@ class OnnxNet:
def __call__(self, image: Any) -> Any: def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name input_name = self.session.get_inputs()[0].name
output_name = self.session.get_outputs()[0].name output_name = self.session.get_outputs()[0].name
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0] output = self.session.run([output_name], {
return OnnxImage(output) input_name: image.cpu().numpy()
})[0]
return OnnxTensor(output)
def eval(self) -> None: def eval(self) -> None:
pass pass