diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 5fa0df923116..332cb9b4e034 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -89,7 +89,7 @@ COPY install/ubuntu_install_jax.sh /install/ubuntu_install_jax.sh RUN bash /install/ubuntu_install_jax.sh "cuda" COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh -RUN bash /install/ubuntu_install_onnx.sh +RUN bash /install/ubuntu_install_onnx.sh "cuda" COPY install/ubuntu_install_libtorch.sh /install/ubuntu_install_libtorch.sh RUN bash /install/ubuntu_install_libtorch.sh diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index a8bebc298810..dc41c39d7c41 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -30,6 +30,9 @@ set -o pipefail # Get the Python version PYTHON_VERSION=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") +# Set default value for first argument +DEVICE=${1:-cpu} + # Install the onnx package pip3 install future @@ -39,28 +42,46 @@ if [ "$PYTHON_VERSION" == "3.9" ]; then onnxruntime==1.19.2 \ onnxoptimizer==0.2.7 - pip3 install \ - torch==2.6.0 \ - torchvision==0.21.0 \ - --extra-index-url https://download.pytorch.org/whl/cpu + if [ "$DEVICE" == "cuda" ]; then + pip3 install \ + torch==2.6.0 \ + torchvision==0.21.0 + else + pip3 install \ + torch==2.6.0 \ + torchvision==0.21.0 \ + --extra-index-url https://download.pytorch.org/whl/cpu + fi elif [ "$PYTHON_VERSION" == "3.11" ]; then pip3 install \ onnx==1.17.0 \ onnxruntime==1.20.1 \ onnxoptimizer==0.2.7 - pip3 install \ - torch==2.6.0 \ - torchvision==0.21.0 \ - --extra-index-url https://download.pytorch.org/whl/cpu + if [ "$DEVICE" == "cuda" ]; then + pip3 install \ + torch==2.6.0 \ + torchvision==0.21.0 + else + pip3 install \ + torch==2.6.0 \ + torchvision==0.21.0 \ + --extra-index-url https://download.pytorch.org/whl/cpu + fi else pip3 install \ onnx==1.12.0 \ onnxruntime==1.12.1 \ onnxoptimizer==0.2.7 - pip3 install \ - torch==2.4.1 \ - torchvision==0.19.1 \ - --extra-index-url https://download.pytorch.org/whl/cpu + if [ "$DEVICE" == "cuda" ]; then + pip3 install \ + torch==2.4.1 \ + torchvision==0.19.1 + else + pip3 install \ + torch==2.4.1 \ + torchvision==0.19.1 \ + --extra-index-url https://download.pytorch.org/whl/cpu + fi fi