From abc1ae511245649e3c7498e1564f5330b72c71b0 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 21 Jan 2023 19:59:58 -0600 Subject: [PATCH] fix(api): put conversion RNG on training device (#67) --- api/onnx_web/convert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py index 18987572..4060fda6 100644 --- a/api/onnx_web/convert.py +++ b/api/onnx_web/convert.py @@ -84,7 +84,7 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int): model.to(training_device).train(False) model.eval() - rng = torch.rand(1, 3, 64, 64) + rng = torch.rand(1, 3, 64, 64, device=map_location) input_names = ['data'] output_names = ['output'] dynamic_axes = {'data': {2: 'width', 3: 'height'}, @@ -134,7 +134,7 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int): model.to(training_device).train(False) model.eval() - rng = torch.rand(1, 3, 64, 64) + rng = torch.rand(1, 3, 64, 64, device=map_location) input_names = ['data'] output_names = ['output'] dynamic_axes = {'data': {2: 'width', 3: 'height'},