diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 46061c91..c9f11c0d 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -29,6 +29,9 @@ class ONNXImage(): self.source = source self.data = self + def __getitem__(self, *args): + return torch.from_numpy(self.source.__getitem__(*args)).to(torch.float32) + def squeeze(self): self.source = np.squeeze(self.source, (0)) return self