diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py index 18987572..4060fda6 100644 --- a/api/onnx_web/convert.py +++ b/api/onnx_web/convert.py @@ -84,7 +84,7 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int): model.to(training_device).train(False) model.eval() - rng = torch.rand(1, 3, 64, 64) + rng = torch.rand(1, 3, 64, 64, device=map_location) input_names = ['data'] output_names = ['output'] dynamic_axes = {'data': {2: 'width', 3: 'height'}, @@ -134,7 +134,7 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int): model.to(training_device).train(False) model.eval() - rng = torch.rand(1, 3, 64, 64) + rng = torch.rand(1, 3, 64, 64, device=map_location) input_names = ['data'] output_names = ['output'] dynamic_axes = {'data': {2: 'width', 3: 'height'},