Port prefill.cuh to HIP with CDNA3 MMA intrinsics support #16
diptorupd wants to merge 123 commits into ROCm:amd-integration from
Conversation
…efill_v3' into feature/hipified_prefill_v3
Pull Request Overview
This PR introduces a HIP port of the single_prefill kernel with CDNA3 MMA intrinsics, specifically focused on adding AMD GPU support to the FlashInfer attention mechanism. The implementation includes comprehensive testing infrastructure for validating GPU kernels against CPU reference implementations.
Key changes include:
- HIP-specific single prefill kernel implementation with CDNA3 MMA intrinsics
- Comprehensive test infrastructure for various attention mechanism components
- CPU reference implementations for validation against GPU kernels
Reviewed Changes
Copilot reviewed 21 out of 23 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| test_prefill_sfrag.sh | Script for automated testing across multiple thread/warp configurations |
| sfrag_tester_script.py | Python utility for parsing and visualizing s_frag debug output from GPU kernels |
| libflashinfer/utils/utils_hip.h | HIP utility functions with deterministic random number generation and new data generation patterns |
| libflashinfer/utils/flashinfer_prefill_ops.hip.h | Main HIP interface for single prefill operations with CDNA3 support |
| libflashinfer/utils/cpu_reference_hip.h | CPU reference implementations for validating GPU attention kernels |
| libflashinfer/tests/hip/test_single_prefill.cpp | Primary test suite for single prefill kernel correctness |
| libflashinfer/tests/hip/test_*.cpp | Various component-specific test suites for memory patterns, MMA operations, and kernel validation |
| libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh | Parameter structures for prefill operations |
| libflashinfer/include/flashinfer/attention/generic/dispatch.cuh | Compile-time dispatch macros for different kernel configurations |
| examples/cpp/standalone_single_prefill.cu | Standalone example demonstrating single prefill usage |
| compile_test.sh | Build script for HIP compilation with appropriate flags |
Comments suppressed due to low confidence (1)
sfrag_tester_script.py:1
- This line contains C++ syntax in a Python file. It should be `return att` for Python.
#!/usr/bin/env python3
libflashinfer/utils/utils_hip.h
Outdated
template <typename T, Predicate Pred>
void generate_data(std::vector<T>& vec) {
  if constexpr (Pred == Predicate::Linear) {
    assert(vec.size() <= 0);
The assertion condition is incorrect. It should check if the vector size is greater than 0 for the Linear predicate case. The current condition vec.size() <= 0 will always trigger an assertion failure for non-empty vectors, preventing the linear data generation from working.
Suggested change:
-    assert(vec.size() <= 0);
+    assert(vec.size() > 0);
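To illustrate the fix, here is a minimal sketch of how the corrected `generate_data` could look. The `Predicate` enum and the linear-fill body are assumptions for illustration; only the corrected assertion comes from the review comment.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative sketch, not the PR's exact code: the Linear branch should
// assert that the vector is non-empty, then fill it with 0, 1, 2, ...
enum class Predicate { Linear };

template <typename T, Predicate Pred>
void generate_data(std::vector<T>& vec) {
    if constexpr (Pred == Predicate::Linear) {
        assert(vec.size() > 0);  // corrected: fail only on an empty vector
        for (std::size_t i = 0; i < vec.size(); ++i) {
            vec[i] = static_cast<T>(i);
        }
    }
}
```

With the original `vec.size() <= 0` condition, any non-empty vector would trip the assertion; the corrected condition only rejects empty input.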
@diptorupd Good job! Do you plan to upstream the code to FlashInfer?
Please see: flashinfer-ai#1678
Pull Request Overview
Copilot reviewed 17 out of 19 changed files in this pull request and generated 3 comments.
void vec_normal_(std::vector<T>& vec, float mean = 0.f, float std = 1.f) {
  std::random_device rd{};
-  std::mt19937 gen{rd()};
+  std::mt19937 gen{1234};
Hard-coded seed values (1234) in random number generators make debugging easier but should be configurable for production use. Consider adding a parameter or compile-time flag to control deterministic vs. random seeding.
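One way to address this review comment is an optional seed parameter: deterministic by default for debugging, truly random when the caller opts out. This is a sketch of one possible design, not the PR's actual code.

```cpp
#include <cstdint>
#include <optional>
#include <random>
#include <vector>

// Illustrative sketch: a configurable seed instead of a hard-coded 1234.
// Passing std::nullopt restores non-deterministic seeding for production.
template <typename T>
void vec_normal_(std::vector<T>& vec, float mean = 0.f, float std = 1.f,
                 std::optional<std::uint32_t> seed = 1234) {
    std::mt19937 gen{seed ? *seed : std::random_device{}()};
    std::normal_distribution<float> dist{mean, std};
    for (auto& x : vec) {
        x = static_cast<T>(dist(gen));
    }
}
```

A compile-time flag (e.g. an `#ifdef` guard) would achieve the same with zero runtime cost, at the price of needing a rebuild to switch modes.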
std::cout << "DEBUG: Original Q (CPU): " << '\n';
for (auto i = 0ul; i < 128; ++i) {
  for (int j = 0; j < 64; ++j) {
    std::cout << (float)q[info.get_q_elem_offset(i, 0, j)] << " ";
  }
  std::cout << std::endl;
}
std::cout << std::endl;

std::cout << "DEBUG: Original K (CPU): " << '\n';
for (auto i = 0ul; i < 128; ++i) {
  for (int j = 0ul; j < 64; ++j) {
    std::cout << (float)k[info.get_kv_elem_offset(i, 0, j)] << " ";
  }
  std::cout << std::endl;
}
std::cout << std::endl;
Large debug blocks with hard-coded dimensions (128, 64) reduce code maintainability. Consider extracting debug functionality into separate functions or using configurable debug levels with template parameters.
Suggested change (replace the hard-coded debug blocks above with a helper):

// Debug print helper for Q and K
auto debug_print_matrix = [](const auto& tensor, auto elem_offset_fn, size_t rows, size_t cols,
                             const std::string& name) {
  std::cout << "DEBUG: " << name << " (CPU): " << '\n';
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      std::cout << (float)tensor[elem_offset_fn(i, 0, j)] << " ";
    }
    std::cout << std::endl;
  }
  std::cout << std::endl;
};
// Print up to 128 rows and 64 columns, or the actual size if smaller
size_t print_rows = std::min(qo_len, size_t(128));
size_t print_cols = std::min(head_dim, size_t(64));
debug_print_matrix(q, [&](size_t i, size_t h, size_t j) { return info.get_q_elem_offset(i, h, j); },
                   print_rows, print_cols, "Original Q");
debug_print_matrix(k, [&](size_t i, size_t h, size_t j) { return info.get_kv_elem_offset(i, h, j); },
                   print_rows, print_cols, "Original K");
sfrag_tester_script.py
Outdated
Populate the matrix with values from a specific thread and warp.

Args:
    matrix: The 128x128 numpy array to populate
    thread_id: The thread ID (0-63)
    warp_id: The warp ID (0-3)
    values: List of 64 float values from this thread
Function lacks comprehensive docstring explaining the matrix population algorithm and the relationship between thread/warp IDs and matrix positions. This is critical for understanding the attention mechanism's data layout.
Suggested change:

Populates the 128x128 matrix with values produced by a specific thread and warp, following the data layout used in attention mechanisms.

The mapping algorithm works as follows:
- Each warp (warp_id: 0-3) is responsible for a contiguous block of 32 rows in the matrix:
  * warp 0: rows 0-31
  * warp 1: rows 32-63
  * warp 2: rows 64-95
  * warp 3: rows 96-127
- Each thread (thread_id: 0-63) within a warp is mapped to a specific set of rows and columns:
  * thread_row_base = (thread_id // 16) * 4
  * row_base = warp_row_offset + thread_row_base
  * col_base = thread_id % 16
- Each thread produces 64 values, split into two calls of 32 values each:
  * The first 32 values are placed in columns 0-63, starting at (row_base, col_base).
  * The second 32 values are placed in columns 64-127, starting at (row_base, col_base).
- The process_call_values function handles the placement of each 32-value block into the matrix.

This mapping ensures that each thread/warp combination fills a unique region of the matrix, matching the parallel computation pattern of attention mechanisms on GPUs.

Args:
    matrix: The 128x128 numpy array to populate.
    thread_id: The thread ID (0-63).
    warp_id: The warp ID (0-3).
    values: List of 64 float values produced by this thread.
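The index arithmetic in the docstring above can be sketched as a small standalone function. The function name and return type are illustrative; the formulas are exactly the ones stated in the docstring (`warp_row_offset = warp_id * 32`, `thread_row_base = (thread_id / 16) * 4`, `col_base = thread_id % 16`).

```cpp
#include <cstddef>
#include <utility>

// Illustrative sketch of the thread/warp -> matrix position mapping
// described above. Returns the (row_base, col_base) where a given
// thread in a given warp starts writing its values.
std::pair<std::size_t, std::size_t> base_position(std::size_t thread_id,
                                                  std::size_t warp_id) {
    const std::size_t warp_row_offset = warp_id * 32;         // warps 0-3 own rows 0, 32, 64, 96
    const std::size_t thread_row_base = (thread_id / 16) * 4; // each group of 16 lanes is 4 rows apart
    const std::size_t row_base = warp_row_offset + thread_row_base;
    const std::size_t col_base = thread_id % 16;              // lane position within a 16-wide tile
    return {row_base, col_base};
}
```

For example, thread 17 in warp 2 starts at row 64 + 4 = 68, column 1, so each thread/warp pair lands in a distinct region of the 128x128 matrix.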
Superseded by #31 |
…2039)

## 📌 Description
cuDNN versions specified in the CI container setup (`docker/install/install_python_packages.sh`) are currently 9.11 and 9.12. In unit testing this causes failures, as `mm_fp4(backend='cudnn')` is not supported on Spark (sm121) for older cuDNN versions in cu130: the cuDNN version shipped with the container is too old. In the [latest container build pipeline output](https://github.com/flashinfer-ai/flashinfer/actions/runs/18778064727/job/53577233568#step:6:727), cudnn 9.13.0.50 is installed:

```
#16 207.0 Requirement already satisfied: nvidia-cudnn-cu13>=9.12.0.46 in /opt/conda/envs/py312/lib/python3.12/site-packages (9.13.0.50)
#16 207.0 Requirement already satisfied: nvidia-cublas in /opt/conda/envs/py312/lib/python3.12/site-packages (from nvidia-cudnn-cu13>=9.12.0.46) (13.0.0.19)
```

This PR updates the minimum cudnn version for both [cu12](https://pypi.org/project/nvidia-cudnn-cu12/#history) and [cu13](https://pypi.org/project/nvidia-cudnn-cu13/#history) to 9.14.0.64.

With cudnn 9.13, the unit test fails (180 failed, 270 passed, 2790 skipped):

```
# pytest tests/gemm/test_mm_fp4.py
...
FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-256] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_
FAILED tests/gemm/test_mm_fp4.py::test_mm_fp4[mxfp4_alpha-False-True-cudnn-res_dtype1-512-512-512] - cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Matmul_MUL_
180 failed, 270 passed, 2790 skipped, 1 warning in 8.97s
```

With cudnn 9.14, the unit test passes (450 passed, 2790 skipped):

```
# pytest tests/gemm/test_mm_fp4.py
...
450 passed, 2790 skipped, 1 warning in 5.37s
```

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks
- [x] `pre-commit` installed, hooks installed, and `pre-commit run --all-files` run with reported issues fixed.

## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing.
The initial set of changes to our CMake scripts to add support for building ported (hipified) kernels using HIP.

Co-authored-by: Diptorup Deb <3046810+diptorupd@users.noreply.github.com>
Work-in-progress port of the single_prefill kernel to HIP with CDNA3 MMA intrinsics.