diff --git a/README.md b/README.md index 90839608..c94159b2 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ Please [see the User Guide](https://github.com/ssube/onnx-web/blob/main/docs/use - [Note about setup paths](#note-about-setup-paths) - [Create a virtual environment](#create-a-virtual-environment) - [Install pip packages](#install-pip-packages) - - [For AMD on Linux: Install ONNX ROCm](#for-amd-on-linux-install-onnx-rocm) + - [For AMD on Linux: Install PyTorch and ONNX ROCm](#for-amd-on-linux-install-pytorch-and-onnx-rocm) - [For AMD on Windows: Install ONNX DirectML](#for-amd-on-windows-install-onnx-directml) - [For CPU on Linux: Install PyTorch CPU](#for-cpu-on-linux-install-pytorch-cpu) - [For CPU on Windows: Install PyTorch CPU](#for-cpu-on-windows-install-pytorch-cpu) @@ -191,11 +191,14 @@ sure you are not using `numpy>=1.24`. [This SO question](https://stackoverflow.com/questions/74844262/how-to-solve-error-numpy-has-no-attribute-float-in-python) has more details. -#### For AMD on Linux: Install ONNX ROCm +#### For AMD on Linux: Install PyTorch and ONNX ROCm -If you are running on Linux with an AMD GPU, download and install the ROCm version of `onnxruntime`: +If you are running on Linux with an AMD GPU, download and install the ROCm version of `onnxruntime` and the ROCm +version of PyTorch: ```shell +> pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2 + > wget https://download.onnxruntime.ai/onnxruntime_training-1.13.0.dev20221021001%2Brocm523-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl > pip install ./onnxruntime_training-1.13.0.dev20221021001%2Brocm523-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -251,9 +254,9 @@ If you are running with an Nvidia GPU on any operating system, install `onnxrunt PyTorch: ```shell -> pip install onnxruntime-gpu +> pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 -> pip install torch --extra-index-url https://download.pytorch.org/whl/cu117 +> pip install onnxruntime-gpu ``` Make sure you have installed CUDA 11.x and that the version of PyTorch matches the version of CUDA