diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py index eeda2ebc..18987572 100644 --- a/api/onnx_web/convert.py +++ b/api/onnx_web/convert.py @@ -53,7 +53,7 @@ extra_models: Models = { model_path = environ.get('ONNX_WEB_MODEL_PATH', path.join('..', 'models')) training_device = 'cuda' if torch.cuda.is_available() else 'cpu' -map_location = None if torch.cuda.is_available() else torch.device('cpu') +map_location = torch.device(training_device) @torch.no_grad() def convert_real_esrgan(name: str, url: str, scale: int, opset: int):