2 changes: 1 addition & 1 deletion .github/workflows/nightly-release.yml
@@ -145,7 +145,7 @@ jobs:
- name: Build wheel in container
env:
DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }}
-FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }}
+FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0f 12.0f' }}
FLASHINFER_DEV_RELEASE_SUFFIX: ${{ needs.setup.outputs.dev_suffix }}
run: |
# Extract CUDA major and minor versions
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -182,7 +182,7 @@ jobs:
- name: Build wheel in container
env:
DOCKER_IMAGE: ${{ matrix.arch == 'aarch64' && format('pytorch/manylinuxaarch64-builder:cuda{0}', matrix.cuda) || format('pytorch/manylinux2_28-builder:cuda{0}', matrix.cuda) }}
-FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 12.0a' }}
+FLASHINFER_CUDA_ARCH_LIST: ${{ matrix.cuda == '12.8' && '7.5 8.0 8.9 9.0a 10.0a 12.0a' || '7.5 8.0 8.9 9.0a 10.0a 10.3a 11.0f 12.0f' }}
run: |
# Extract CUDA major and minor versions
CUDA_MAJOR=$(echo "${{ matrix.cuda }}" | cut -d'.' -f1)
2 changes: 1 addition & 1 deletion README.md
@@ -90,7 +90,7 @@ python -m pip install dist/*.whl

`flashinfer-jit-cache` (customize `FLASHINFER_CUDA_ARCH_LIST` for your target GPUs):
```bash
-export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a"
+export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0f 12.0f"
cd flashinfer-jit-cache
python -m build --no-isolation --wheel
python -m pip install dist/*.whl
3 changes: 2 additions & 1 deletion csrc/xqa/mha.cu
@@ -89,7 +89,8 @@ constexpr uint32_t cvtExpansion = exactDiv(inputElemSize, cacheElemSize);
constexpr uint32_t preferedKHeadPartBytes = 64;
__constant__ constexpr uint32_t cacheVTileSeqLen = 32;
#else
-#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210
+#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || \
+    __CUDA_ARCH__ == 1100 || __CUDA_ARCH__ == 1200 || __CUDA_ARCH__ == 1210
constexpr uint32_t preferedKHeadPartBytes = 64;
__constant__ constexpr uint32_t cacheVTileSeqLen = 32;
#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 900 || \
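For context (not part of the patch): `__CUDA_ARCH__` encodes compute capability major.minor as `major*100 + minor*10`, so the value `1100` handled by this branch corresponds to SM 11.0, matching the `11.0f` entry added to the arch lists above. A minimal sketch of that mapping:

```python
# Illustration only (not part of mha.cu): compute capability X.Y maps to
# __CUDA_ARCH__ == X*100 + Y*10.
def cuda_arch_macro(major: int, minor: int) -> int:
    return major * 100 + minor * 10

assert cuda_arch_macro(8, 6) == 860    # SM 8.6
assert cuda_arch_macro(11, 0) == 1100  # SM 11.0, newly handled by this #if branch
assert cuda_arch_macro(12, 1) == 1210  # SM 12.1
```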
2 changes: 1 addition & 1 deletion docs/installation.rst
@@ -92,7 +92,7 @@ You can follow the steps below to install FlashInfer from source code:

.. code-block:: bash

-export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 12.0a"
+export FLASHINFER_CUDA_ARCH_LIST="7.5 8.0 8.9 10.0a 10.3a 11.0f 12.0f"
cd flashinfer-jit-cache
python -m build --no-isolation --wheel
python -m pip install dist/*.whl
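As an aside (not part of the patch): to pick `FLASHINFER_CUDA_ARCH_LIST` entries for your own GPUs, one way is to query their compute capability, e.g. via PyTorch (assumed to be installed, since FlashInfer depends on it):

```python
# Hypothetical helper, not part of the repository: list compute capabilities
# of the local GPUs so matching FLASHINFER_CUDA_ARCH_LIST entries can be chosen.
import torch

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        print(f"GPU {i}: compute capability {major}.{minor}")
```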
11 changes: 10 additions & 1 deletion scripts/task_test_jit_cache_package_build_import.sh
@@ -43,7 +43,16 @@ arches = ["7.5", "8.0", "8.9", "9.0a"]
if cuda_ver is not None:
    try:
        major, minor = map(int, cuda_ver.split(".")[:2])
-        if (major, minor) >= (12, 8):
+        if (major, minor) >= (13, 0):
+            arches.append("10.0a")
+            arches.append("10.3a")
+            arches.append("11.0f")
+            arches.append("12.0f")
+        elif (major, minor) >= (12, 9):
+            arches.append("10.0a")
+            arches.append("10.3a")
+            arches.append("12.0f")
+        elif (major, minor) >= (12, 8):
            arches.append("10.0a")
            arches.append("12.0a")
    except Exception:
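For reference, the version-gated selection above can be exercised standalone; the sketch below mirrors the script's logic (the helper name is illustrative, not part of the script):

```python
# Mirrors the arch selection from task_test_jit_cache_package_build_import.sh
# (illustrative only; select_arches is not a name used by the script).
def select_arches(cuda_ver: str) -> list[str]:
    arches = ["7.5", "8.0", "8.9", "9.0a"]
    major, minor = map(int, cuda_ver.split(".")[:2])
    if (major, minor) >= (13, 0):
        arches += ["10.0a", "10.3a", "11.0f", "12.0f"]
    elif (major, minor) >= (12, 9):
        arches += ["10.0a", "10.3a", "12.0f"]
    elif (major, minor) >= (12, 8):
        arches += ["10.0a", "12.0a"]
    return arches

assert "11.0f" in select_arches("13.0")                                   # CUDA 13.x adds 11.0f and 12.0f
assert "11.0f" not in select_arches("12.9") and "12.0f" in select_arches("12.9")
assert select_arches("12.8") == ["7.5", "8.0", "8.9", "9.0a", "10.0a", "12.0a"]
```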