Adds a HIPified version of the SinglePrefillWithKVCacheDevice kernel#31

Merged
diptorupd merged 13 commits into ROCm:amd-integration from
diptorupd:feature/hipified_prefill_v4
Nov 5, 2025

Conversation

@diptorupd
Collaborator

@diptorupd diptorupd commented Oct 22, 2025

Initial CDNA3 single prefill kernel using MFMA intrinsics.

Ports the SinglePrefillWithKVCacheDevice kernel to HIP using CDNA3 MFMA intrinsics. The following kernels have been ported:

  • load_q_global_smem
  • produce_kv
  • compute_qk
  • update_mdo_states
  • compute_sfm_v

Unit test source is /libflashinfer/tests/hip/test_single_prefill.cpp

Known issues:

  1. The HIPification of write_o_reg_gmem introduced indexing bugs that cause the wrong data to be copied back to global memory. As a result, the single prefill kernel does not yet work end-to-end.
  2. The shared memory swizzle logic is disabled for now; linear indexing without swizzles is used instead. The performance fix for bank conflicts will be done as a follow-up.

Copilot AI review requested due to automatic review settings October 22, 2025 21:35
@diptorupd diptorupd marked this pull request as draft October 22, 2025 21:35

Copilot AI left a comment


Pull Request Overview

This PR implements HIPification of the prefill attention kernel, adapting FlashInfer's attention mechanism for AMD's CDNA3 architecture (MI300). The changes include new test infrastructure, utility functions, and modifications to existing code to support AMD GPUs while maintaining compatibility.

Key changes:

  • Added comprehensive test suite for single prefill operations and memory access patterns
  • Introduced utility functions for data generation and deterministic random number generation
  • Implemented CDNA3-specific matrix transpose operations for MFMA tiles
  • Added debug capabilities with configurable thread/warp ID tracking

Reviewed Changes

Copilot reviewed 13 out of 15 changed files in this pull request and generated 4 comments.

Changed files:

  • validate_online_softmax_stateful.py: New validation script for stateful online softmax operations
  • libflashinfer/utils/utils_hip.h: Added data generation utilities and a fixed random seed for reproducibility
  • libflashinfer/utils/flashinfer_prefill_ops.hip.h: New HIP-specific prefill operation wrappers
  • libflashinfer/utils/cpu_reference_hip.h: Enhanced CPU reference implementation with debug output and soft cap support
  • libflashinfer/tests/hip/test_single_prefill.cpp: Comprehensive test suite for single prefill correctness
  • libflashinfer/tests/hip/test_q_smem_read_pattern.cpp: Test for query shared memory read patterns
  • libflashinfer/tests/hip/test_k_smem_read_pattern.cpp: Test for key shared memory read patterns
  • libflashinfer/include/gpu_iface/mma_ops.hpp: Added transpose_mma_tile API
  • libflashinfer/include/gpu_iface/backend/hip/mma_hip.h: Implemented CDNA3 matrix transpose and refactored rowsum
  • libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h: Enhanced debug utilities with improved formatting and generic fragment writing
  • libflashinfer/include/flashinfer/attention/generic/page.cuh: Removed an unused variable
  • libflashinfer/include/flashinfer/attention/generic/dispatch.cuh: New dispatch macros for runtime-to-compile-time constant conversion
  • libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh: Added debug parameters and updated license header


void generate_data(std::vector<T>& vec) {
if constexpr (Pred == Predicate::Linear) {
assert(vec.size() > 0);
for (int i = 0; i < vec.size(); i++) {

Copilot AI Oct 22, 2025


Use size_t instead of int for the loop counter to match vec.size() return type and avoid signed/unsigned comparison warnings.


template <typename T>
void vec_lexicographic_(std::vector<T>& vec) {
for (int i = 0; i < vec.size(); i++) {

Copilot AI Oct 22, 2025


Use size_t instead of int for the loop counter to match vec.size() return type and avoid signed/unsigned comparison warnings.

Suggested change
for (int i = 0; i < vec.size(); i++) {
for (size_t i = 0; i < vec.size(); i++) {

@diptorupd diptorupd force-pushed the feature/hipified_prefill_v4 branch 2 times, most recently from cfa478c to efbe86d Compare November 5, 2025 15:31
@diptorupd diptorupd changed the title from Feature/hipified prefill v4 to Adds a HIPified version of the SinglePrefillWithKVCacheDevice kernel Nov 5, 2025
@diptorupd diptorupd force-pushed the feature/hipified_prefill_v4 branch 3 times, most recently from f2c175a to e87bdd6 Compare November 5, 2025 15:46
@diptorupd diptorupd marked this pull request as ready for review November 5, 2025 16:20
Copilot AI review requested due to automatic review settings November 5, 2025 16:20
@diptorupd diptorupd force-pushed the feature/hipified_prefill_v4 branch from e87bdd6 to 55d8816 Compare November 5, 2025 16:21

Copilot AI left a comment


Pull Request Overview

Copilot reviewed 11 out of 12 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (1)

libflashinfer/include/gpu_iface/backend/hip/mma_hip.h:1

  • This line was removed but the function m16k16_rowsum_f16f16f32 expects s_frag to be in A-matrix layout. Without calling transpose_intra_quad_fragments, the fragment may not be in the expected layout for the subsequent MFMA operation, potentially causing incorrect results.
// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.


Copilot AI review requested due to automatic review settings November 5, 2025 16:26

Copilot AI left a comment


Pull Request Overview

Copilot reviewed 11 out of 12 changed files in this pull request and generated 2 comments.



// apply soft cap if enabled
if (use_soft_cap) {
float soft_cap_pre_tanh_scale = sm_scale / logits_soft_cap;
att[kv_idx] = std::tanh(att[kv_idx] / sm_scale * soft_cap_pre_tanh_scale);

Copilot AI Nov 5, 2025


The soft cap calculation appears incorrect. The value is divided by sm_scale then multiplied by soft_cap_pre_tanh_scale (which is sm_scale / logits_soft_cap), effectively canceling out sm_scale. This should likely be std::tanh(att[kv_idx] * soft_cap_pre_tanh_scale) * logits_soft_cap to properly apply the soft cap transformation.

Suggested change
att[kv_idx] = std::tanh(att[kv_idx] / sm_scale * soft_cap_pre_tanh_scale);
att[kv_idx] = std::tanh(att[kv_idx] * soft_cap_pre_tanh_scale) * logits_soft_cap;

const char* title = "LDS Array (float)") {
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
printf("%s (%dx%d):\n", title, dimX, dimY);
printf("%s (%dx%d):\n", title, dimY, dimX);

Copilot AI Nov 5, 2025


Dimension order is reversed in the printf statement. The format string shows dimensions as (dimY x dimX) but the original code intended (dimX x dimY). This was likely changed to match the actual memory layout but creates inconsistency with the parameter names.

Suggested change
printf("%s (%dx%d):\n", title, dimY, dimX);
printf("%s (%dx%d):\n", title, dimX, dimY);

Copilot AI review requested due to automatic review settings November 5, 2025 16:49

Copilot AI left a comment


Pull Request Overview

Copilot reviewed 11 out of 12 changed files in this pull request and generated 1 comment.



Collaborator

@demandal25 demandal25 left a comment


Left some comments about dead code/comments

Copilot AI review requested due to automatic review settings November 5, 2025 20:07

Copilot AI left a comment


Pull Request Overview

Copilot reviewed 11 out of 12 changed files in this pull request and generated 1 comment.



Copilot AI review requested due to automatic review settings November 5, 2025 20:13

Copilot AI left a comment


Pull Request Overview

Copilot reviewed 11 out of 12 changed files in this pull request and generated no new comments.



Copilot AI review requested due to automatic review settings November 5, 2025 20:20

Copilot AI left a comment


Pull Request Overview

Copilot reviewed 12 out of 13 changed files in this pull request and generated 4 comments.



@@ -1,8 +1,7 @@
// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team
// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc.
// SPDX-FileCopyrightText : 2023-2035 FlashInfer team.

Copilot AI Nov 5, 2025


Year '2035' appears to be a typo; should likely be '2025' to match the current year range.

Suggested change
// SPDX-FileCopyrightText : 2023-2035 FlashInfer team.
// SPDX-FileCopyrightText : 2023-2025 FlashInfer team.

* See the License for the specific language governing permissions and
* limitations under the License.
*/
// SPDX-FileCopyrightText : 2023-2035 FlashInfer team.

Copilot AI Nov 5, 2025


Year '2035' appears to be a typo; should likely be '2025' to match the current year range.

Suggested change
// SPDX-FileCopyrightText : 2023-2035 FlashInfer team.
// SPDX-FileCopyrightText : 2023-2025 FlashInfer team.

@@ -1,7 +1,7 @@
// SPDX - FileCopyrightText : 2023-2035 FlashInfer team.
// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc.
// SPDX-FileCopyrightText : 2023-2035 FlashInfer team.

Copilot AI Nov 5, 2025


Year '2035' appears to be a typo; should likely be '2025' to match the current year range.

Suggested change
// SPDX-FileCopyrightText : 2023-2035 FlashInfer team.
// SPDX-FileCopyrightText : 2023-2025 FlashInfer team.

diptorupd and others added 2 commits November 5, 2025 14:51
diptorupd and others added 11 commits November 5, 2025 14:51
@diptorupd diptorupd force-pushed the feature/hipified_prefill_v4 branch from 964239e to 74f8dbd Compare November 5, 2025 20:51
@demandal25
Collaborator

LGTM. Thanks for the PR :)

@demandal25 demandal25 self-requested a review November 5, 2025 21:25
@diptorupd diptorupd merged commit eab214b into ROCm:amd-integration Nov 5, 2025
1 check passed
@diptorupd diptorupd deleted the feature/hipified_prefill_v4 branch November 5, 2025 21:29
diptorupd pushed a commit that referenced this pull request Dec 3, 2025
This PR ports the BatchPrefillWithPagedKVCacheDevice kernel to HIP.
Along with some indexing changes and chunking logic required for the
batch prefill (similar to #31), it ports the `page_produce_kv` kernel
that is unique to the batch prefill.

To sanity test the changes, run `python examples/batch_prefill_examples.py`; it should pass all tests.

**Known issues:**

1. It supports only the `partition_kv=False` case. The port of the other
case is WIP.
2. Running the pytest `test_batch_prefill_paged_kernels_hip.py`
currently results in `618 failed, 1710 passed`. We are investigating
whether fixing the `partition_kv=False` case resolves the failures.

---------

Co-authored-by: Debasis Mandal <debasis.mandal@amd.com>
diptorupd added a commit that referenced this pull request Dec 5, 2025
…31)

diptorupd pushed a commit that referenced this pull request Dec 5, 2025
zhenhantech pushed a commit to zhenhantech/flashinfer that referenced this pull request Jan 9, 2026
…OCm#31)

zhenhantech pushed a commit to zhenhantech/flashinfer that referenced this pull request Jan 9, 2026
diptorupd added a commit to diptorupd/flashinfer that referenced this pull request Jan 28, 2026
…OCm#31)

diptorupd pushed a commit to diptorupd/flashinfer that referenced this pull request Jan 28, 2026