feat(api): use ONNX for Real ESRGAN v3 model
This commit is contained in:
parent
3dde3b9237
commit
2c9d96d2ee
|
@ -37,16 +37,6 @@ def load_resrgan(
|
|||
if not path.isfile(model_path):
|
||||
raise Exception("Real ESRGAN model not found at %s" % model_path)
|
||||
|
||||
if x4_v3_tag in model_file:
|
||||
# the x4-v3 model needs a different network
|
||||
model = SRVGGNetCompact(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_conv=32,
|
||||
upscale=4,
|
||||
act_type="prelu",
|
||||
)
|
||||
elif params.format == "onnx":
|
||||
# use ONNX acceleration, if available
|
||||
model = OnnxNet(
|
||||
|
@ -56,14 +46,25 @@ def load_resrgan(
|
|||
sess_options=device.sess_options(),
|
||||
)
|
||||
elif params.format == "pth":
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=params.scale,
|
||||
)
|
||||
if x4_v3_tag in model_file:
|
||||
# the x4-v3 model needs a different network
|
||||
model = SRVGGNetCompact(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_conv=32,
|
||||
upscale=4,
|
||||
act_type="prelu",
|
||||
)
|
||||
else:
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=params.scale,
|
||||
)
|
||||
else:
|
||||
raise Exception("unknown platform %s" % params.format)
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
from .onnx_net import OnnxImage, OnnxNet
|
||||
from .onnx_net import OnnxTensor, OnnxNet
|
||||
|
|
|
@ -8,7 +8,7 @@ from onnxruntime import InferenceSession, SessionOptions
|
|||
from ..utils import ServerContext
|
||||
|
||||
|
||||
class OnnxImage:
|
||||
class OnnxTensor:
|
||||
def __init__(self, source) -> None:
|
||||
self.source = source
|
||||
self.data = self
|
||||
|
@ -57,8 +57,10 @@ class OnnxNet:
|
|||
def __call__(self, image: Any) -> Any:
|
||||
input_name = self.session.get_inputs()[0].name
|
||||
output_name = self.session.get_outputs()[0].name
|
||||
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
|
||||
return OnnxImage(output)
|
||||
output = self.session.run([output_name], {
|
||||
input_name: image.cpu().numpy()
|
||||
})[0]
|
||||
return OnnxTensor(output)
|
||||
|
||||
def eval(self) -> None:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue