File tree Expand file tree Collapse file tree 2 files changed +34
-13
lines changed
Expand file tree Collapse file tree 2 files changed +34
-13
lines changed Original file line number Diff line number Diff line change @@ -89,7 +89,7 @@ COPY install/ubuntu_install_jax.sh /install/ubuntu_install_jax.sh
8989RUN bash /install/ubuntu_install_jax.sh "cuda"
9090
9191COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh
92- RUN bash /install/ubuntu_install_onnx.sh
92+ RUN bash /install/ubuntu_install_onnx.sh "cuda"
9393
9494COPY install/ubuntu_install_libtorch.sh /install/ubuntu_install_libtorch.sh
9595RUN bash /install/ubuntu_install_libtorch.sh
Original file line number Diff line number Diff line change @@ -30,6 +30,9 @@ set -o pipefail
3030# Get the Python version
3131PYTHON_VERSION=$( python3 -c " import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')" )
3232
33+ # Set default value for first argument
34+ DEVICE=${1:- cpu}
35+
3336# Install the onnx package
3437pip3 install future
3538
@@ -39,28 +42,46 @@ if [ "$PYTHON_VERSION" == "3.9" ]; then
3942 onnxruntime==1.19.2 \
4043 onnxoptimizer==0.2.7
4144
42- pip3 install \
43- torch==2.6.0 \
44- torchvision==0.21.0 \
45- --extra-index-url https://download.pytorch.org/whl/cpu
45+ if [ " $DEVICE " == " cuda" ]; then
46+ pip3 install \
47+ torch==2.6.0 \
48+ torchvision==0.21.0
49+ else
50+ pip3 install \
51+ torch==2.6.0 \
52+ torchvision==0.21.0 \
53+ --extra-index-url https://download.pytorch.org/whl/cpu
54+ fi
4655elif [ " $PYTHON_VERSION " == " 3.11" ]; then
4756 pip3 install \
4857 onnx==1.17.0 \
4958 onnxruntime==1.20.1 \
5059 onnxoptimizer==0.2.7
5160
52- pip3 install \
53- torch==2.6.0 \
54- torchvision==0.21.0 \
55- --extra-index-url https://download.pytorch.org/whl/cpu
61+ if [ " $DEVICE " == " cuda" ]; then
62+ pip3 install \
63+ torch==2.6.0 \
64+ torchvision==0.21.0
65+ else
66+ pip3 install \
67+ torch==2.6.0 \
68+ torchvision==0.21.0 \
69+ --extra-index-url https://download.pytorch.org/whl/cpu
70+ fi
5671else
5772 pip3 install \
5873 onnx==1.12.0 \
5974 onnxruntime==1.12.1 \
6075 onnxoptimizer==0.2.7
6176
62- pip3 install \
63- torch==2.4.1 \
64- torchvision==0.19.1 \
65- --extra-index-url https://download.pytorch.org/whl/cpu
77+ if [ " $DEVICE " == " cuda" ]; then
78+ pip3 install \
79+ torch==2.4.1 \
80+ torchvision==0.19.1
81+ else
82+ pip3 install \
83+ torch==2.4.1 \
84+ torchvision==0.19.1 \
85+ --extra-index-url https://download.pytorch.org/whl/cpu
86+ fi
6687fi
You can’t perform that action at this time.
0 commit comments