Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile.ci_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ COPY install/ubuntu_install_jax.sh /install/ubuntu_install_jax.sh
RUN bash /install/ubuntu_install_jax.sh "cuda"

COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh
RUN bash /install/ubuntu_install_onnx.sh
RUN bash /install/ubuntu_install_onnx.sh "cuda"

COPY install/ubuntu_install_libtorch.sh /install/ubuntu_install_libtorch.sh
RUN bash /install/ubuntu_install_libtorch.sh
Expand Down
45 changes: 33 additions & 12 deletions docker/install/ubuntu_install_onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ set -o pipefail
# Get the Python version
PYTHON_VERSION=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")

# Set default value for first argument
DEVICE=${1:-cpu}

# Install the onnx package
pip3 install future

Expand All @@ -39,28 +42,46 @@ if [ "$PYTHON_VERSION" == "3.9" ]; then
onnxruntime==1.19.2 \
onnxoptimizer==0.2.7

pip3 install \
torch==2.6.0 \
torchvision==0.21.0 \
--extra-index-url https://download.pytorch.org/whl/cpu
if [ "$DEVICE" == "cuda" ]; then
pip3 install \
torch==2.6.0 \
torchvision==0.21.0
else
pip3 install \
torch==2.6.0 \
torchvision==0.21.0 \
--extra-index-url https://download.pytorch.org/whl/cpu
fi
elif [ "$PYTHON_VERSION" == "3.11" ]; then
pip3 install \
onnx==1.17.0 \
onnxruntime==1.20.1 \
onnxoptimizer==0.2.7

pip3 install \
torch==2.6.0 \
torchvision==0.21.0 \
--extra-index-url https://download.pytorch.org/whl/cpu
if [ "$DEVICE" == "cuda" ]; then
pip3 install \
torch==2.6.0 \
torchvision==0.21.0
else
pip3 install \
torch==2.6.0 \
torchvision==0.21.0 \
--extra-index-url https://download.pytorch.org/whl/cpu
fi
else
pip3 install \
onnx==1.12.0 \
onnxruntime==1.12.1 \
onnxoptimizer==0.2.7

pip3 install \
torch==2.4.1 \
torchvision==0.19.1 \
--extra-index-url https://download.pytorch.org/whl/cpu
if [ "$DEVICE" == "cuda" ]; then
pip3 install \
torch==2.4.1 \
torchvision==0.19.1
else
pip3 install \
torch==2.4.1 \
torchvision==0.19.1 \
--extra-index-url https://download.pytorch.org/whl/cpu
fi
fi
Loading