From a9822679b94e619770938d334814ee0c4c35fd0e Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 23 Dec 2024 14:15:16 +0530 Subject: [PATCH] Prune torch cuda arch list to match upstream --- recipe/bld.bat | 3 ++- recipe/build.sh | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/recipe/bld.bat b/recipe/bld.bat index 30cc5d4f0..b785aa327 100644 --- a/recipe/bld.bat +++ b/recipe/bld.bat @@ -16,7 +16,8 @@ if "%build_with_cuda%" == "" goto cuda_flags_end set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%desired_cuda% set CUDA_BIN_PATH=%CUDA_PATH%\bin -set TORCH_CUDA_ARCH_LIST=5.0;6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX +REM Keep this list in sync with https://github.com/pytorch/pytorch/blob/07fa6e2c8b003319f85a469307f1b1dd73f6026c/.ci/manywheel/build_cuda.sh#L60 +set TORCH_CUDA_ARCH_LIST=5.0;6.0;7.0;7.5;8.0;8.6;9.0+PTX set TORCH_NVCC_FLAGS=-Xfatbin -compress-all :cuda_flags_end diff --git a/recipe/build.sh b/recipe/build.sh index 19da0ad40..851f5c58a 100644 --- a/recipe/build.sh +++ b/recipe/build.sh @@ -168,7 +168,8 @@ elif [[ ${cuda_compiler_version} != "None" ]]; then esac case ${cuda_compiler_version} in 12.6) - export TORCH_CUDA_ARCH_LIST="5.0;6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0+PTX" + # Keep this list in sync with https://github.com/pytorch/pytorch/blob/07fa6e2c8b003319f85a469307f1b1dd73f6026c/.ci/manywheel/build_cuda.sh#L60 + export TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0+PTX" ;; *) echo "unsupported cuda version. edit build.sh"