Adds a HIPified version of the SinglePrefillWithKVCacheDevice kernel #31
Conversation
Pull Request Overview
This PR implements HIPification of the prefill attention kernel, adapting FlashInfer's attention mechanism for AMD's CDNA3 architecture (MI300). The changes include new test infrastructure, utility functions, and modifications to existing code to support AMD GPUs while maintaining compatibility.
Key changes:
- Added comprehensive test suite for single prefill operations and memory access patterns
- Introduced utility functions for data generation and deterministic random number generation
- Implemented CDNA3-specific matrix transpose operations for MFMA tiles
- Added debug capabilities with configurable thread/warp ID tracking
Reviewed Changes
Copilot reviewed 13 out of 15 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| validate_online_softmax_stateful.py | New validation script for stateful online softmax operations |
| libflashinfer/utils/utils_hip.h | Added data generation utilities and fixed random seed for reproducibility |
| libflashinfer/utils/flashinfer_prefill_ops.hip.h | New HIP-specific prefill operation wrappers |
| libflashinfer/utils/cpu_reference_hip.h | Enhanced CPU reference implementation with debug output and soft cap support |
| libflashinfer/tests/hip/test_single_prefill.cpp | Comprehensive test suite for single prefill correctness |
| libflashinfer/tests/hip/test_q_smem_read_pattern.cpp | Test for query shared memory read patterns |
| libflashinfer/tests/hip/test_k_smem_read_pattern.cpp | Test for key shared memory read patterns |
| libflashinfer/include/gpu_iface/mma_ops.hpp | Added transpose_mma_tile API |
| libflashinfer/include/gpu_iface/backend/hip/mma_hip.h | Implemented CDNA3 matrix transpose and refactored rowsum |
| libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h | Enhanced debug utilities with improved formatting and generic fragment writing |
| libflashinfer/include/flashinfer/attention/generic/page.cuh | Removed unused variable |
| libflashinfer/include/flashinfer/attention/generic/dispatch.cuh | New dispatch macros for runtime-to-compile-time constant conversion |
| libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh | Added debug parameters and updated license header |
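The dispatch.cuh macros convert runtime parameters into compile-time template arguments. A minimal sketch of that runtime-to-compile-time pattern follows; the macro name, supported head dimensions, and kernel stub are hypothetical illustrations, not the PR's actual code:

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical sketch of a runtime-to-compile-time dispatch macro: a
// runtime head_dim selects among a fixed set of template instantiations.
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)              \
  do {                                                          \
    switch (head_dim) {                                         \
      case 64: {                                                \
        constexpr int HEAD_DIM = 64;                            \
        __VA_ARGS__;                                            \
        break;                                                  \
      }                                                         \
      case 128: {                                               \
        constexpr int HEAD_DIM = 128;                           \
        __VA_ARGS__;                                            \
        break;                                                  \
      }                                                         \
      default:                                                  \
        throw std::invalid_argument("unsupported head_dim");    \
    }                                                           \
  } while (0)

// Stands in for a templated kernel launch that needs HEAD_DIM at compile time.
template <int HeadDim>
int kernel_stub() {
  return HeadDim;
}

int dispatch_example(int head_dim) {
  int out = 0;
  DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { out = kernel_stub<HEAD_DIM>(); });
  return out;
}
```

The payoff of this pattern is that loop bounds and register-array sizes inside the kernel become compile-time constants, at the cost of one instantiation per supported value.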
```cpp
void generate_data(std::vector<T>& vec) {
  if constexpr (Pred == Predicate::Linear) {
    assert(vec.size() > 0);
    for (int i = 0; i < vec.size(); i++) {
```
Use size_t instead of int for the loop counter to match vec.size() return type and avoid signed/unsigned comparison warnings.
```cpp
template <typename T>
void vec_lexicographic_(std::vector<T>& vec) {
  for (int i = 0; i < vec.size(); i++) {
```
Use size_t instead of int for the loop counter to match vec.size() return type and avoid signed/unsigned comparison warnings.
Suggested change:
```diff
- for (int i = 0; i < vec.size(); i++) {
+ for (size_t i = 0; i < vec.size(); i++) {
```
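Applying the suggested `size_t` fix, the data-generation utility would look roughly like this. This is a simplified sketch: the `Predicate` enum and the seeded random branch are assumptions based on the PR summary ("fixed random seed for reproducibility"), not the actual `utils_hip.h` code:

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

enum class Predicate { Linear, Random };  // assumed variants

template <Predicate Pred, typename T>
void generate_data(std::vector<T>& vec) {
  assert(!vec.empty());
  if constexpr (Pred == Predicate::Linear) {
    // size_t matches vec.size() and avoids signed/unsigned warnings.
    for (size_t i = 0; i < vec.size(); i++) {
      vec[i] = static_cast<T>(i);
    }
  } else {
    std::mt19937 gen(42);  // fixed seed => identical data on every run
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    for (size_t i = 0; i < vec.size(); i++) {
      vec[i] = static_cast<T>(dist(gen));
    }
  }
}
```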
Resolved review comment on libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
Force-pushed cfa478c to efbe86d
Force-pushed f2c175a to e87bdd6
Force-pushed e87bdd6 to 55d8816
Pull Request Overview
Copilot reviewed 11 out of 12 changed files in this pull request and generated 4 comments.
Comments suppressed due to low confidence (1)
libflashinfer/include/gpu_iface/backend/hip/mma_hip.h:1
- This line was removed, but the function `m16k16_rowsum_f16f16f32` expects `s_frag` to be in A-matrix layout. Without calling `transpose_intra_quad_fragments`, the fragment may not be in the expected layout for the subsequent MFMA operation, potentially causing incorrect results.
// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
Resolved (outdated) review comment on libflashinfer/include/flashinfer/attention/generic/dispatch.cuh
Resolved (outdated) review comment on libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
Pull Request Overview
Copilot reviewed 11 out of 12 changed files in this pull request and generated 2 comments.
```cpp
// apply soft cap if enabled
if (use_soft_cap) {
  float soft_cap_pre_tanh_scale = sm_scale / logits_soft_cap;
  att[kv_idx] = std::tanh(att[kv_idx] / sm_scale * soft_cap_pre_tanh_scale);
```
The soft cap calculation appears incorrect. The value is divided by sm_scale then multiplied by soft_cap_pre_tanh_scale (which is sm_scale / logits_soft_cap), effectively canceling out sm_scale. This should likely be std::tanh(att[kv_idx] * soft_cap_pre_tanh_scale) * logits_soft_cap to properly apply the soft cap transformation.
Suggested change:
```diff
- att[kv_idx] = std::tanh(att[kv_idx] / sm_scale * soft_cap_pre_tanh_scale);
+ att[kv_idx] = std::tanh(att[kv_idx] * soft_cap_pre_tanh_scale) * logits_soft_cap;
```
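For context, the suggested form matches the usual logits soft cap transform: raw q·k scores are scaled by `sm_scale`, squashed by `tanh`, then rescaled so the final logit lies in `(-logits_soft_cap, logits_soft_cap)`. A standalone sketch of just that transform (hypothetical helper, not the PR's code):

```cpp
#include <cassert>
#include <cmath>

// score -> logits_soft_cap * tanh(score * sm_scale / logits_soft_cap):
// approximately linear (score * sm_scale) for small scores, saturating
// at +/- logits_soft_cap for large ones.
float apply_soft_cap(float qk, float sm_scale, float logits_soft_cap) {
  float pre_tanh_scale = sm_scale / logits_soft_cap;
  return std::tanh(qk * pre_tanh_scale) * logits_soft_cap;
}
```

Note how `sm_scale` survives the transform (it sets the slope near zero), which is exactly what the original expression cancelled out.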
```diff
  const char* title = "LDS Array (float)") {
    if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
-     printf("%s (%dx%d):\n", title, dimX, dimY);
+     printf("%s (%dx%d):\n", title, dimY, dimX);
```
Dimension order is reversed in the printf statement. The format string shows dimensions as (dimY x dimX) but the original code intended (dimX x dimY). This was likely changed to match the actual memory layout but creates inconsistency with the parameter names.
Suggested change:
```diff
- printf("%s (%dx%d):\n", title, dimY, dimX);
+ printf("%s (%dx%d):\n", title, dimX, dimY);
```
Pull Request Overview
Copilot reviewed 11 out of 12 changed files in this pull request and generated 1 comment.
demandal25 left a comment:
Left some comments about dead code/comments
Pull Request Overview
Copilot reviewed 11 out of 12 changed files in this pull request and generated 1 comment.
Resolved (outdated) review comment on libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh
Pull Request Overview
Copilot reviewed 11 out of 12 changed files in this pull request and generated no new comments.
Pull Request Overview
Copilot reviewed 12 out of 13 changed files in this pull request and generated 4 comments.
```diff
@@ -1,8 +1,7 @@
-// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team
-// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc.
+// SPDX-FileCopyrightText : 2023-2035 FlashInfer team.
```
Year '2035' appears to be a typo; should likely be '2025' to match the current year range.
Suggested change:
```diff
- // SPDX-FileCopyrightText : 2023-2035 FlashInfer team.
+ // SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
```
```diff
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+ // SPDX-FileCopyrightText : 2023-2035 FlashInfer team.
```
Year '2035' appears to be a typo; should likely be '2025' to match the current year range.
Suggested change:
```diff
- // SPDX-FileCopyrightText : 2023-2035 FlashInfer team.
+ // SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
```
```diff
@@ -1,7 +1,7 @@
-// SPDX - FileCopyrightText : 2023-2035 FlashInfer team.
-// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc.
+// SPDX-FileCopyrightText : 2023-2035 FlashInfer team.
```
Year '2035' appears to be a typo; should likely be '2025' to match the current year range.
Suggested change:
```diff
- // SPDX-FileCopyrightText : 2023-2035 FlashInfer team.
+ // SPDX-FileCopyrightText : 2023-2025 FlashInfer team.
```
Ports the `SinglePrefillWithKVCacheDevice` kernel to HIP along with using CDNA3 MFMA intrinsics. The following kernels have been ported:

- `load_q_global_smem`
- `produce_kv`
- `compute_qk`
- `update_mdo_states`
- `compute_sfm_v`

Unit test source is `/libflashinfer/tests/hip/test_single_prefill.cpp`
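Among the ported kernels, `update_mdo_states` maintains the running max (m), softmax denominator (d), and output accumulator (o) of the online softmax. A scalar sketch of that recurrence, illustrative only since the real kernel operates on MFMA fragment tiles:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Online (streaming) softmax state: scores are folded in one at a time,
// rescaling the old state whenever a new running max appears, so no
// exponential ever overflows.
struct MDOState {
  float m = -INFINITY;  // running max of scores seen so far
  float d = 0.0f;       // running sum of exp(score - m)
  float o = 0.0f;       // running exp-weighted sum of values
};

void update_mdo(MDOState& s, float score, float value) {
  float m_new = std::fmax(s.m, score);
  float scale = std::exp(s.m - m_new);  // rescales old d and o to the new max
  float p = std::exp(score - m_new);
  s.d = s.d * scale + p;
  s.o = s.o * scale + p * value;
  s.m = m_new;
}

// One attention "row": softmax(scores) dotted with values.
float attend(const std::vector<float>& scores, const std::vector<float>& values) {
  MDOState s;
  for (size_t i = 0; i < scores.size(); i++) update_mdo(s, scores[i], values[i]);
  return s.o / s.d;
}
```

With uniform scores this reduces to the plain mean of the values, and a dominant score makes its value win, which gives quick sanity checks for a reference implementation.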
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Diptorup Deb <diptorup@cs.unc.edu>
Force-pushed 964239e to 74f8dbd
LGTM. Thanks for the PR :)
This PR ports the BatchPrefillWithPagedKVCacheDevice kernel to HIP. Along with some indexing changes and chunking logic required for the batch prefill (similar to #31), it ports the `page_produce_kv` kernel that is unique to the batch prefill.

To sanity test the changes, run `python examples/batch_prefill_examples.py` and it should pass all tests.

**Known issues:**

1. It supports only the `partition_kv=False` case. The port of the other case is WIP.
2. Running the pytest `test_batch_prefill_paged_kernels_hip.py` currently results in `618 failed, 1710 passed`. We are investigating whether fixing `partition_kv=False` will pass the failed ones.

---------

Co-authored-by: Debasis Mandal <debasis.mandal@amd.com>
…31) Initial CDNA3 single prefill kernel using MFMA intrinsics.

Ports the `SinglePrefillWithKVCacheDevice` kernel to HIP along with using CDNA3 MFMA intrinsics. The following kernels have been ported:

- `load_q_global_smem`
- `produce_kv`
- `compute_qk`
- `update_mdo_states`
- `compute_sfm_v`

Unit test source is `/libflashinfer/tests/hip/test_single_prefill.cpp`

**Known issues:**

1. The HIPification of `write_o_reg_gmem` introduced some indexing bugs that are causing the wrong data to be copied back to global memory. As such, the single prefill kernel does not yet work end-to-end.
2. The shared memory swizzle logic was disabled for now and we are using linear indexing without swizzles. The performance fix for bank conflicts will be done as a follow-up.

---------

Signed-off-by: Diptorup Deb <diptorup@cs.unc.edu>