apply lint
This commit is contained in:
parent
338fc237c7
commit
b3b10b4746
|
@ -46,7 +46,7 @@ def load_resrgan(
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
elif params.format == "pth":
|
elif params.format == "pth":
|
||||||
if TAG_X4_V3 in model_file:
|
if TAG_X4_V3 in model_file:
|
||||||
# the x4-v3 model needs a different network
|
# the x4-v3 model needs a different network
|
||||||
model = SRVGGNetCompact(
|
model = SRVGGNetCompact(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
|
|
|
@ -57,9 +57,7 @@ class OnnxNet:
|
||||||
def __call__(self, image: Any) -> Any:
|
def __call__(self, image: Any) -> Any:
|
||||||
input_name = self.session.get_inputs()[0].name
|
input_name = self.session.get_inputs()[0].name
|
||||||
output_name = self.session.get_outputs()[0].name
|
output_name = self.session.get_outputs()[0].name
|
||||||
output = self.session.run([output_name], {
|
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
|
||||||
input_name: image.cpu().numpy()
|
|
||||||
})[0]
|
|
||||||
return OnnxTensor(output)
|
return OnnxTensor(output)
|
||||||
|
|
||||||
def eval(self) -> None:
|
def eval(self) -> None:
|
||||||
|
|
Loading…
Reference in New Issue