@@ -387,30 +387,26 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
387387 if [[ "$CUDA_VERSION" == 12.8* ]]; then
388388 uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
389389 else
390- export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
391- git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive
392- # Needed to build AOT kernels
393- (cd flashinfer && \
394- python3 -m flashinfer.aot && \
395- uv pip install --system --no-build-isolation . \
396- )
397- rm -rf flashinfer
398-
399- # Default arches (skipping 10.0a and 12.0 since these need 12.8)
390+ # Exclude CUDA arches for older versions (11.x and 12.0-12.7)
400391 # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
401- TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
402392 if [[ "${CUDA_VERSION}" == 11.* ]]; then
403- TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
393+ FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
394+ elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
395+ FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
396+ else
397+ # CUDA 12.8+ supports 10.0a and 12.0
398+ FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
404399 fi
405- echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST }"
400+ echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
406401
407402 git clone --depth 1 --recursive --shallow-submodules \
408- --branch v0.2.6.post1 \
409- https://github.com/flashinfer-ai/flashinfer.git flashinfer
403+ --branch ${FLASHINFER_GIT_REF} \
404+ ${FLASHINFER_GIT_REPO} flashinfer
410405
406+ # Needed to build AOT kernels
411407 pushd flashinfer
412408 python3 -m flashinfer.aot
413- TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST }" \
409+ TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
414410 uv pip install --system --no-build-isolation .
415411 popd
416412
0 commit comments