Skip to content

Commit aeadc31

Browse files
authored
[Docker] Use Torch GPU on gpu device (#17676)
1 parent 432ccfa commit aeadc31

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

docker/Dockerfile.ci_gpu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ COPY install/ubuntu_install_jax.sh /install/ubuntu_install_jax.sh
8989
RUN bash /install/ubuntu_install_jax.sh "cuda"
9090

9191
COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh
92-
RUN bash /install/ubuntu_install_onnx.sh
92+
RUN bash /install/ubuntu_install_onnx.sh "cuda"
9393

9494
COPY install/ubuntu_install_libtorch.sh /install/ubuntu_install_libtorch.sh
9595
RUN bash /install/ubuntu_install_libtorch.sh

docker/install/ubuntu_install_onnx.sh

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ set -o pipefail
3030
# Get the Python version
3131
PYTHON_VERSION=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
3232

33+
# Set default value for first argument
34+
DEVICE=${1:-cpu}
35+
3336
# Install the onnx package
3437
pip3 install future
3538

@@ -39,28 +42,46 @@ if [ "$PYTHON_VERSION" == "3.9" ]; then
3942
onnxruntime==1.19.2 \
4043
onnxoptimizer==0.2.7
4144

42-
pip3 install \
43-
torch==2.6.0 \
44-
torchvision==0.21.0 \
45-
--extra-index-url https://download.pytorch.org/whl/cpu
45+
if [ "$DEVICE" == "cuda" ]; then
46+
pip3 install \
47+
torch==2.6.0 \
48+
torchvision==0.21.0
49+
else
50+
pip3 install \
51+
torch==2.6.0 \
52+
torchvision==0.21.0 \
53+
--extra-index-url https://download.pytorch.org/whl/cpu
54+
fi
4655
elif [ "$PYTHON_VERSION" == "3.11" ]; then
4756
pip3 install \
4857
onnx==1.17.0 \
4958
onnxruntime==1.20.1 \
5059
onnxoptimizer==0.2.7
5160

52-
pip3 install \
53-
torch==2.6.0 \
54-
torchvision==0.21.0 \
55-
--extra-index-url https://download.pytorch.org/whl/cpu
61+
if [ "$DEVICE" == "cuda" ]; then
62+
pip3 install \
63+
torch==2.6.0 \
64+
torchvision==0.21.0
65+
else
66+
pip3 install \
67+
torch==2.6.0 \
68+
torchvision==0.21.0 \
69+
--extra-index-url https://download.pytorch.org/whl/cpu
70+
fi
5671
else
5772
pip3 install \
5873
onnx==1.12.0 \
5974
onnxruntime==1.12.1 \
6075
onnxoptimizer==0.2.7
6176

62-
pip3 install \
63-
torch==2.4.1 \
64-
torchvision==0.19.1 \
65-
--extra-index-url https://download.pytorch.org/whl/cpu
77+
if [ "$DEVICE" == "cuda" ]; then
78+
pip3 install \
79+
torch==2.4.1 \
80+
torchvision==0.19.1
81+
else
82+
pip3 install \
83+
torch==2.4.1 \
84+
torchvision==0.19.1 \
85+
--extra-index-url https://download.pytorch.org/whl/cpu
86+
fi
6687
fi

0 commit comments

Comments
 (0)