
feat: add GDN Attention #2276

Merged
yzh119 merged 14 commits into main from feat/GDNAttention on Jan 3, 2026

Conversation

@guangyunh-nv (Collaborator) commented Dec 31, 2025

📌 Description

This PR adds an implementation of the Gated Delta Rule (also known as Gated Delta Net) on the Hopper architecture to better support Qwen-next-like architectures.
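For context, the gated delta rule maintains a matrix-valued state that is decayed by a gate (alpha) and corrected toward each new key/value pair with a step size (beta). A schematic pure-Python sketch of one step of the recurrence, under one common convention (names and memory layout are illustrative, not this kernel's API):

```python
def gated_delta_step(S, q, k, v, alpha, beta):
    """One gated delta-rule step on a d_v x d_k state matrix S.

    S is first decayed by the gate `alpha`, then nudged toward the new
    (k, v) association with step size `beta`; the output is the updated
    state queried with q.
    """
    d_v, d_k = len(S), len(S[0])
    # Decay the running state.
    S = [[alpha * S[i][j] for j in range(d_k)] for i in range(d_v)]
    # Predict v from the current state: p = S @ k.
    p = [sum(S[i][j] * k[j] for j in range(d_k)) for i in range(d_v)]
    # Rank-1 delta-rule correction: S += beta * (v - p) @ k^T.
    S = [[S[i][j] + beta * (v[i] - p[i]) * k[j] for j in range(d_k)]
         for i in range(d_v)]
    # Read out with the query: o = S @ q.
    o = [sum(S[i][j] * q[j] for j in range(d_k)) for i in range(d_v)]
    return S, o
```

With alpha = 1, beta = 1, and a unit-norm k, the updated state reproduces v exactly when queried with k, which is the fixed point the delta rule drives toward; the SM90 kernel computes this same recurrence in chunked, blockwise form.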

🔍 Related Issues

#1690

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Thanks @jiahanc for initiating the kernel integration and implementing the API.

Summary by CodeRabbit

  • New Features

    • SM90-optimized Gated Delta Rule (GDN) prefill: high-level Python API (chunk_gated_delta_rule), host launcher, and FFI export; optional alpha/beta gating and final-state output.
  • Benchmarks & Tests

    • New GPU benchmark for GDN prefill reporting runtime, TFLOPs and bandwidth.
    • Added reference implementations and comprehensive tests covering prefill, chunked prefill, and delta-rule correctness.


jiahanc and others added 10 commits December 12, 2025 01:25
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
@coderabbitai bot (Contributor) commented Dec 31, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds an SM90-targeted Gated Delta Rule (GDN) prefill: new CUDA kernels and launchers, Flat/CUTE/TMA primitives and helpers, Python JIT integration and runtime bindings, a benchmark driver, and tests with Python reference implementations.

Changes

  • Benchmark: `benchmarks/bench_gdn_prefill.py`
    New benchmark driver estimating FLOPs/bytes, running CUDA timings, and CLI presets for GDN prefill.
  • Python API & JIT: `flashinfer/gdn_prefill.py`, `flashinfer/__init__.py`, `flashinfer/jit/gdn.py`, `flashinfer/aot.py`
    JIT spec generator, module loader, exported chunk_gated_delta_rule wrapper, and AOT inclusion for the SM90 module.
  • Host launcher / TVM FFI: `csrc/gdn_prefill_launcher.cu`
    Host dispatch and TVM-exposed gdn_prefill entrypoint with dtype/SM90 guards and calls into flat launchers.
  • Prefill kernel declarations & SM90 launchers: `csrc/flat/prefill/prefill_kernel.hpp`, `.../prefill_kernel_delta_rule_sm90.cuh`, `.../prefill_kernel_delta_rule_sm90.cu`
    Templated SM90 launcher selection (GVA/GQA, alpha/beta, init-state permutations) and explicit FP16/BF16 instantiations.
  • Flat foundational utilities: `csrc/flat/common.hpp`, `csrc/flat/debug.hpp`, `csrc/flat/cute_ext.hpp`, `csrc/flat/unused.hpp`, `csrc/flat/type_traits.hpp`, `csrc/flat/math.hpp`
    New low-level helpers: error/CUDA checks, debug macros, CuTe layout/tensor utilities, Unused placeholder, type-trait mappings, and math helpers.
  • Collective / GMMA & layout utilities: `csrc/flat/collective/...`, `csrc/flat/ampere/collective/flat_collective_inverse.hpp`, `csrc/flat/hopper/collective/flat_common.hpp`
    GMMA/GMMA-RS converters, accumulator→operand transforms, Ampere inverse collective, and collective layout/layout-conversion helpers.
  • TMA Load/Store primitives: `csrc/flat/collective/flat_collective_load.hpp`, `csrc/flat/hopper/collective/flat_collective_load.hpp`, `csrc/flat/hopper/collective/flat_collective_store.hpp`
    New pipeline-aware CollectiveLoad/CollectiveStore (TMA) primitives with partition/step logic and tail/vector handling.
  • Warp-specialized mainloop & kernels: `csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp`, `csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp`, `csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp`
    Large new mainloop and kernel specializations implementing warp/TMA delta-rule pipelines, MMAs, synchronization, and compute/store orchestration.
  • Kernel infra (scheduler/options/universal): `csrc/flat/hopper/kernel/flat_tile_scheduler.hpp`, `csrc/flat/hopper/kernel/flat_options.hpp`, `csrc/flat/hopper/device/device_universal.hpp`, `csrc/flat/hopper/collective/flat_named_barriers.hpp`, `csrc/flat/math_order_barrier.hpp`
    New tile scheduler, tag-based options system, Universal kernel wrapper, named-barrier definitions, and ordered named-barrier utility.
  • Prefill workspace / store helpers: `csrc/flat/hopper/collective/flat_collective_load.hpp`, `csrc/flat/hopper/collective/flat_collective_store.hpp`, `csrc/flat/collective/flat_collective_inverse.hpp`
    Additional SM80/SM90-specific store, layout conversion, and blockwise inverse utilities supporting prefill.
  • Tests / Reference implementations: `tests/gdn/conftest.py`, `tests/gdn/reference_delta_rule.py`, `tests/gdn/test_prefill_delta_rule.py`
    SM90a test guard and fixtures, reference linear/delta-rule implementations, and parametrized tests including chunked prefill validation.

Sequence Diagram(s)

sequenceDiagram
    participant User as Python Caller
    participant PyModule as flashinfer.gdn_prefill
    participant JIT as JIT / nvcc
    participant Launcher as gdn_prefill_launcher (host)
    participant Kernel as SM90 Delta-rule Kernel
    participant TMA as GPU TMA/SMEM

    User->>PyModule: chunk_gated_delta_rule(q,k,v,...)
    PyModule->>JIT: ensure SM90 module loaded (gen_gdn_prefill_sm90_module)
    PyModule->>Launcher: call gdn_prefill_launcher with pointers & flags
    Launcher->>Kernel: select template variant and launch kernel
    Kernel->>TMA: TMA loads/stores, MMAs, barriers (compute/store)
    Kernel-->>Launcher: kernel completes (device memory populated)
    Launcher-->>PyModule: return control
    PyModule-->>User: return output (and state if requested)

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • jimmyzho
  • aleozlx
  • jiahanc
  • Anerudhan
  • cyx-6

Poem

🐰 I hopped through tiles and named-barrier lanes,
I shuffled bytes and tuned the TMA trains,
From Python call to SM90's bright light,
QKV danced and kernels hummed through the night,
A rabbit cheers — the prefill ran just right!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 15.08%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The PR title 'feat: add GDN Attention' accurately describes the main change, implementing Gated Delta Rule (GDN) attention on Hopper. It is concise, clear, and specific enough for a developer scanning history to understand the primary contribution.
  • Description check ✅ Passed: The PR description covers the essential sections: a brief Description explaining the GDN implementation for Hopper to support Qwen-next architectures, a Related Issues link (#1690), completed Pre-commit Checks, and Reviewer Notes acknowledging contributors. While the Tests section is incomplete (items unchecked), the core required content is present.

@gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @guangyunh-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the Gated Delta Rule (GDN) attention mechanism, crucial for modern large language models like Qwen-next. It delivers a high-performance CUDA kernel optimized for Hopper GPUs, exposed through a user-friendly Python interface. This enhancement significantly expands the library's capabilities for advanced attention architectures, ensuring both efficiency and accuracy through rigorous testing and benchmarking.

Highlights

  • Gated Delta Rule (GDN) Attention: Introduces a new implementation for Gated Delta Rule (GDN) attention, also known as Gated Delta Net, designed to support Qwen-next like architectures.
  • CUDA Kernel Implementation: Adds highly optimized CUDA kernels for GDN prefill, leveraging SM90 (Hopper) architecture features like TMA and warp-specialized collectives for efficient computation and memory access.
  • Python API and Benchmarking: Provides a Python API chunk_gated_delta_rule for easy integration and includes a dedicated benchmarking script to measure performance (TFLOPS, TB/s) of the GDN prefill kernel.
  • Comprehensive Testing: Includes a Python reference implementation of the Gated Delta Rule and extensive Pytest tests to ensure numerical correctness for various configurations, including chunked prefill and different head groupings (GQA/GVA).


@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new Gated Delta Rule (GDN) prefill kernel for FlashInfer, targeting the NVIDIA Hopper (SM90) architecture. The changes include a Python benchmark script (bench_gdn_prefill.py) to measure performance metrics like TFLOPS and memory bandwidth for the GDN kernel, along with new C++/CUDA files that implement the core GDN logic.

The C++ implementation leverages the CUTLASS and CuTe libraries for efficient GPU computation, including collective loads for Q, K, V, alpha, and beta tensors using TMA, a blockwise matrix inverse for state updates, and collective stores for the output. The kernel supports both Grouped Query Attention (GQA) and Grouped Value Attention (GVA) configurations, handles initial-state loading, and applies alpha/beta gating. A Python API (flashinfer/gdn_prefill.py) exposes the chunk_gated_delta_rule function, which orchestrates the kernel launch and manages tensor allocations. Unit tests (tests/gdn/) and reference implementations are also included to verify correctness, with specific tests for basic and non-full-block scenarios, as well as chunked prefill.

Review comments highlighted a critical bug in the CHECK macro's stringification and argument passing, dead code in bench_gdn_prefill.py related to peak bandwidth calculation, and an unused use_qk_l2norm_in_kernel parameter in the chunk_gated_delta_rule Python API.

Comment on lines +11 to +18
#define CHECK(expr, msg) \
do { \
if (!(expr)) { \
std::string buffer(1024, '\0'); \
sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \
throw std::runtime_error(buffer.c_str()); \
} \
} while (0)

critical

The CHECK macro is implemented incorrectly. The use of ##expr is not standard for stringification and will likely cause a compilation error. It should be #expr. Additionally, the sprintf call has a mismatch between the number of format specifiers and the number of arguments provided, which leads to undefined behavior. The msg and __FILE__ arguments are also not passed correctly. This is a critical issue in a utility macro that could hide bugs or cause crashes.

Suggested change

Before:

#define CHECK(expr, msg) \
do { \
if (!(expr)) { \
std::string buffer(1024, '\0'); \
sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \
throw std::runtime_error(buffer.c_str()); \
} \
} while (0)

After:

#define CHECK(expr, msg) \
do { \
if (!(expr)) { \
std::string buffer(1024, '\0'); \
sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", #expr, msg, __FILE__, __LINE__); \
throw std::runtime_error(buffer.c_str()); \
} \
} while (0)


# Get device info for bandwidth calculation
props = torch.cuda.get_device_properties(0)
props.total_memory * 2 / 1e12 # Approximate peak bandwidth

medium

The calculation on this line is not used. The result is not assigned to any variable or used later in the function. This appears to be dead code and should be removed to avoid confusion.
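For context, an achieved-bandwidth figure of the kind the benchmark reports (TB/s derived from modeled byte traffic and measured kernel time) needs no device properties at all; a minimal sketch of that computation (function name is illustrative, not the benchmark's API):

```python
def achieved_bandwidth_tb_s(total_bytes: int, elapsed_ms: float) -> float:
    """Achieved memory bandwidth in TB/s from modeled traffic and measured time."""
    return total_bytes / (elapsed_ms * 1e-3) / 1e12
```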

initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False,

medium

The parameter use_qk_l2norm_in_kernel is defined in the function signature and its docstring, but it is not used within the function body. The underlying custom op gdn_prefill is not called with this parameter. This can be misleading for users of the API. This parameter should either be implemented or removed to avoid confusion. The benchmark and test files perform L2 normalization on k in Python before calling this function, which suggests this parameter is indeed unused.
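As the comment notes, the callers L2-normalize k in Python before invoking the function; a minimal pure-Python sketch of that preprocessing on a single head vector (the real code operates on tensors along the head dimension):

```python
def l2_normalize(x, eps=1e-6):
    """Scale x to unit L2 norm, guarding against division by (near) zero."""
    norm = sum(v * v for v in x) ** 0.5
    return [v / max(norm, eps) for v in x]
```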

@yzh119 marked this pull request as ready for review on January 1, 2026 at 01:20
@coderabbitai bot left a comment

Actionable comments posted: 13

♻️ Duplicate comments (3)
csrc/flat/common.hpp (1)

11-18: Fix the CHECK macro: stringification and missing comma.

This was previously flagged. The macro has two issues:

  1. ##expr should be #expr for proper stringification
  2. Missing comma between msg and __FILE__
🔎 Proposed fix
 #define CHECK(expr, msg)                                                                           \
   do {                                                                                             \
     if (!(expr)) {                                                                                 \
       std::string buffer(1024, '\0');                                                              \
-      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \
+      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", #expr, msg, __FILE__, __LINE__); \
       throw std::runtime_error(buffer.c_str());                                                    \
     }                                                                                              \
   } while (0)
flashinfer/gdn_prefill.py (1)

89-89: Unused parameter use_qk_l2norm_in_kernel.

This parameter is documented but not passed to the kernel. This was already flagged in a previous review.

benchmarks/bench_gdn_prefill.py (1)

199-201: Remove dead code.

This expression computes a value but doesn't assign or use it. This was already flagged in a previous review.

🧹 Nitpick comments (21)
csrc/flat/math.hpp (1)

9-19: Consider documenting edge case behavior for n=0.

The current implementation returns 1 for next_power_of_two(0) since ceil_log2(0) returns 0. This may be intentional, but could be unexpected. Consider adding a brief comment or static_assert if n > 0 is a precondition.
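The noted edge case is easy to pin down; a pure-Python model of the presumed semantics of the two helpers (the actual C++ signatures and implementations may differ):

```python
def ceil_log2(n: int) -> int:
    """Smallest b with 2**b >= n; returns 0 for n <= 1, including n == 0."""
    bits = 0
    while (1 << bits) < n:
        bits += 1
    return bits

def next_power_of_two(n: int) -> int:
    return 1 << ceil_log2(n)
```

Under this model next_power_of_two(0) == 1, matching the behavior described above; a static_assert on n > 0 in the C++ version would make the precondition explicit instead.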

csrc/flat/cute_ext.hpp (1)

8-8: Avoid using namespace in header files.

using namespace cute; in a header pollutes the namespace of all translation units that include this header, potentially causing name collisions. Consider using explicit cute:: prefixes or limiting the directive to a narrower scope.

csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh (1)

119-133: Include status details in error messages for easier debugging.

The generic error messages ("can_implement failed", "initialize failed", "run failed") don't provide insight into why the operation failed. Consider including the cutlass::Status value or its string representation.

🔎 Proposed improvement
     status = op.can_implement(arguments);
     if (status != cutlass::Status::kSuccess) {
-      throw std::runtime_error("can_implement failed");
+      throw std::runtime_error(std::string("can_implement failed: ") +
+                               cutlass::cutlassGetStatusString(status));
     }
 
     status = op.initialize(arguments, workspace.get(), stream);
     if (status != cutlass::Status::kSuccess) {
-      throw std::runtime_error("initialize failed");
+      throw std::runtime_error(std::string("initialize failed: ") +
+                               cutlass::cutlassGetStatusString(status));
     }
 
     status = op.run(stream);
     if (status != cutlass::Status::kSuccess) {
-      throw std::runtime_error("run failed");
+      throw std::runtime_error(std::string("run failed: ") +
+                               cutlass::cutlassGetStatusString(status));
     }
csrc/flat/ampere/collective/flat_collective_load.hpp (1)

10-10: Avoid using namespace in header files.

Similar to cute_ext.hpp, this pollutes the namespace for all includers. Consider using explicit cute:: prefixes.

csrc/flat/ampere/collective/flat_collective_inverse.hpp (1)

10-10: Avoid using namespace in header files.

Consistent with other headers in this PR, this introduces namespace pollution for all includers.

flashinfer/jit/gdn.py (1)

25-37: Missing supported_major_versions to restrict compilation to SM90+.

Per coding guidelines, JitSpec should specify supported_major_versions to restrict kernel compilation to supported GPU architectures. Since this is an SM90-specific module, it should explicitly declare this constraint.

🔎 Proposed fix
 def gen_gdn_prefill_sm90_module() -> JitSpec:
     return gen_jit_spec(
         name="gdn_prefill_launcher",
         sources=[
             jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_launcher.cu",
             jit_env.FLASHINFER_CSRC_DIR
             / "flat"
             / "prefill"
             / "prefill_kernel_delta_rule_sm90.cu",
         ],
-        extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
+        extra_cuda_cflags=[*sm90a_nvcc_flags, "-DFLAT_SM90A_ENABLED", "-std=c++20"],
         extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
+        supported_major_versions=[9, 10, 11, 12],
     )

Based on learnings, supported_major_versions should be specified in JitSpec for SM-specific kernels.

tests/gdn/conftest.py (1)

34-35: Consider removing commented-out code.

The commented-out multidist_randn line appears to be debug/development artifact. Consider removing it to keep the codebase clean.

csrc/flat/math_order_barrier.hpp (1)

31-34: Track the FIXME for persistent scheduler support.

The destructor contains a FIXME indicating that the current implementation will have issues with a persistent scheduler. This should be tracked for future work.

Would you like me to open an issue to track the persistent scheduler compatibility work for OrderedNamedBarriers?

csrc/flat/debug.hpp (1)

22-37: DPRINTF_W and DPRINTF_WG have identical implementations.

Both macros print identical output with [WG%d][W%d][T%-3d] prefix and the same condition IS_PRINT_BLOCK. This appears to be unintentional duplication. If they're meant to serve different purposes (e.g., warp-level vs warp-group-level filtering), consider differentiating them.

Suggested differentiation

If DPRINTF_WG is intended to be the warp-group variant, perhaps filter to only print from warp-group leaders:

 #define DPRINTF_WG(fmt, ...)                                                                       \
-  if (IS_PRINT_BLOCK)                                                                              \
+  if (IS_PRINT_BLOCK && threadIdx.x % 128 < 32)                                                    \
   printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \
          threadIdx.x, ##__VA_ARGS__)
flashinfer/gdn_prefill.py (1)

176-182: Consider simplifying output_state allocation logic.

Lines 170-182 allocate output_state whether or not output_final_state is True. The conditional on line 176 (elif not output_final_state and output_state is None) could be merged with the first condition since both branches do the same allocation.

Simplified allocation
-    # Allocate output_state if needed
-    if output_final_state and output_state is None:
-        output_state = torch.empty(
-            (num_seqs, num_sab_heads, head_size, head_size),
-            dtype=torch.float32,
-            device=q.device,
-        )
-    elif not output_final_state and output_state is None:
-        # Still need to allocate since kernel always writes state
+    # Allocate output_state if not provided (kernel always writes state)
+    if output_state is None:
         output_state = torch.empty(
             (num_seqs, num_sab_heads, head_size, head_size),
             dtype=torch.float32,
             device=q.device,
         )
tests/gdn/test_prefill_delta_rule.py (1)

23-28: Unused block_size parameter.

The block_size parameter is defined but never used in _test_prefill_kernel. If it's intended for future use, consider adding a TODO comment. Otherwise, remove it.

benchmarks/bench_gdn_prefill.py (1)

25-56: Consider removing unused parameters from gdn_flops signature.

num_k_heads and num_seqs are not used in the FLOPs calculation. While keeping them for API consistency with gdn_bytes is reasonable, documenting why they're unused would be helpful.

Option 1: Add underscore prefix to indicate intentionally unused
 def gdn_flops(
     total_seq_len: int,
     num_q_heads: int,
-    num_k_heads: int,
+    _num_k_heads: int,  # unused, kept for API consistency
     num_v_heads: int,
     head_size: int,
-    num_seqs: int,
+    _num_seqs: int,  # unused, state ops not counted in TFLOPS
 ) -> int:
csrc/gdn_prefill_launcher.cu (1)

42-68: Device properties are queried redundantly.

cudaGetDeviceProperties is called both in gdn_prefill (lines 157-161) and again inside gdn_prefill_launcher (lines 43-46). Since sm_count is already passed as a parameter, consider also passing the device major version to avoid the redundant CUDA API call.

🔎 Suggested refactor

Pass the device major version as a parameter to avoid querying device properties twice:

 void gdn_prefill_launcher(void* output, void* output_state, void* q, void* k, void* v,
                           void* input_state, void* alpha, void* beta, int64_t* cu_seqlens,
                           int64_t num_seqs, int64_t num_q_heads, int64_t num_k_heads,
                           int64_t num_v_heads, int64_t num_o_heads, int64_t head_size,
-                          int64_t packed_seq, float scale, int64_t sm_count, DLDataType dtype,
+                          int64_t packed_seq, float scale, int64_t sm_count, int device_major, DLDataType dtype,
                           cudaStream_t stream) {
   DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(dtype, DType, [&] {
-    int dev_id;
-    cudaGetDevice(&dev_id);
-    cudaDeviceProp device_properties;
-    cudaGetDeviceProperties(&device_properties, dev_id);
-
 #if defined(FLAT_SM90A_ENABLED)
-    if (device_properties.major == 9) {
+    if (device_major == 9) {
csrc/flat/hopper/collective/flat_common.hpp (1)

94-146: Complex byte reordering logic for 1-byte elements.

The shuffle-based byte reordering (lines 99-143) implements a specific permutation pattern for converting accumulator layout to operand layout for 1-byte elements. The magic constants (0x3021, 0x2130, 0x1054, 0x5410, 0x3276, 0x7632) encode the permutation.

This is intricate low-level code. Consider adding a comment explaining the mathematical transformation or referencing documentation for the pattern.

tests/gdn/reference_delta_rule.py (2)

110-111: @torch.inference_mode decorator should include parentheses.

The decorator @torch.inference_mode on lines 110 and 331 is missing parentheses. While this works in recent PyTorch versions, the canonical form is @torch.inference_mode().

🔎 Proposed fix
-@torch.inference_mode
+@torch.inference_mode()
 def blockwise_linear_attention(
-@torch.inference_mode
+@torch.inference_mode()
 def blockwise_delta_rule(

Also applies to: 331-332


138-138: Debug artifact: FIXME comment about KVs list.

Line 138 contains KVs = [] # FIXME: kernel debug only. This debug artifact is collected (line 222) but only used in the return value. Consider removing it if not needed for production testing.

csrc/flat/hopper/kernel/flat_options.hpp (1)

57-60: Unused function parameters.

The new_option and options_tuple parameters are unused. Since this is a compile-time operation, only the types matter, but the parameters create unnecessary copies and may trigger compiler warnings.

🔎 Proposed fix
 template <auto kTag, typename Value, typename... Options>
-constexpr auto add_option(Option<kTag, Value> new_option, std::tuple<Options...> options_tuple) {
+constexpr auto add_option(Option<kTag, Value> /*new_option*/, std::tuple<Options...> /*options_tuple*/) {
   return add_option_t<kTag, Value, Options...>();
 }
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (4)

20-23: Clarify the intent of this workaround macro.

The condition thread_idx > 8192 appears to be intentionally unreachable (typical warp groups have far fewer threads). If this is a no-op workaround to prevent compiler optimizations that cause performance issues, consider adding a brief comment explaining the mechanism and referencing any related bug/issue.


123-124: Remove extraneous semicolon.

There's a stray semicolon after the DummyStages type alias.

🔎 Proposed fix
   using DummyStages = cutlass::gemm::collective::StageCount<2>;
-  ;

326-339: Remove commented-out code.

The BetaProcessor struct is entirely commented out. If it's no longer needed, remove it. If it's planned for future use, consider tracking it in an issue rather than leaving dead code.


1217-1221: Simplify valid_seq_len logic.

The ternary condition can be simplified using min.

🔎 Proposed fix
   template <typename WorkDesc>
   CUTE_DEVICE int valid_seq_len(WorkDesc work_desc, int blk_idx) {
-    int remain_len = work_desc.seq_len - BlkSeqKV * blk_idx;
-    return remain_len <= BlkSeqKV ? remain_len : BlkSeqKV;
+    return min(work_desc.seq_len - BlkSeqKV * blk_idx, BlkSeqKV);
   }
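The equivalence of the two forms can be checked exhaustively on small inputs; a Python model of the original and proposed versions (BLK_SEQ_KV stands in for the BlkSeqKV template constant):

```python
BLK_SEQ_KV = 64  # illustrative tile size; the kernel's value is a template parameter

def valid_seq_len_ternary(seq_len: int, blk_idx: int) -> int:
    # Original form: clamp the remaining length with a conditional.
    remain = seq_len - BLK_SEQ_KV * blk_idx
    return remain if remain <= BLK_SEQ_KV else BLK_SEQ_KV

def valid_seq_len_min(seq_len: int, blk_idx: int) -> int:
    # Proposed form: the same clamp expressed with min().
    return min(seq_len - BLK_SEQ_KV * blk_idx, BLK_SEQ_KV)

# The two agree for every (seq_len, blk_idx) pair in a small sweep.
assert all(
    valid_seq_len_ternary(s, b) == valid_seq_len_min(s, b)
    for s in range(513) for b in range(8)
)
```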
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 835a015 and 16e5331.

📒 Files selected for processing (30)
  • benchmarks/bench_gdn_prefill.py
  • csrc/flat/ampere/collective/flat_collective_inverse.hpp
  • csrc/flat/ampere/collective/flat_collective_load.hpp
  • csrc/flat/common.hpp
  • csrc/flat/cute_ext.hpp
  • csrc/flat/debug.hpp
  • csrc/flat/hopper/collective/flat_collective_load.hpp
  • csrc/flat/hopper/collective/flat_collective_store.hpp
  • csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
  • csrc/flat/hopper/collective/flat_common.hpp
  • csrc/flat/hopper/collective/flat_named_barriers.hpp
  • csrc/flat/hopper/device/device_universal.hpp
  • csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp
  • csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
  • csrc/flat/hopper/kernel/flat_options.hpp
  • csrc/flat/hopper/kernel/flat_tile_scheduler.hpp
  • csrc/flat/math.hpp
  • csrc/flat/math_order_barrier.hpp
  • csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu
  • csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh
  • csrc/flat/type_traits.hpp
  • csrc/flat/unused.hpp
  • csrc/gdn_prefill_launcher.cu
  • flashinfer/__init__.py
  • flashinfer/aot.py
  • flashinfer/gdn_prefill.py
  • flashinfer/jit/gdn.py
  • tests/gdn/conftest.py
  • tests/gdn/reference_delta_rule.py
  • tests/gdn/test_prefill_delta_rule.py
🧰 Additional context used
📓 Path-based instructions (6)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/gdn/conftest.py
  • tests/gdn/reference_delta_rule.py
  • tests/gdn/test_prefill_delta_rule.py
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/jit/gdn.py
  • flashinfer/__init__.py
  • flashinfer/aot.py
  • flashinfer/gdn_prefill.py
flashinfer/jit/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/jit/**/*.py: JIT module generators in flashinfer/jit/ must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Use gen_jit_spec() function to return a properly configured JitSpec from module generators with appropriate sources and extra_cuda_cflags
Specify supported_major_versions in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)

Files:

  • flashinfer/jit/gdn.py
flashinfer/__init__.py

📄 CodeRabbit inference engine (CLAUDE.md)

Export new operations in flashinfer/__init__.py to make them available as public API

Files:

  • flashinfer/__init__.py
flashinfer/aot.py

📄 CodeRabbit inference engine (CLAUDE.md)

Register new operations in flashinfer/aot.py by calling the gen_*_module() function for AOT (Ahead-Of-Time) pre-compilation support

Files:

  • flashinfer/aot.py
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/gdn_prefill_launcher.cu
  • csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu
🧠 Learnings (11)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

Applied to files:

  • tests/gdn/conftest.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • csrc/flat/common.hpp
  • csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh
  • csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`

Applied to files:

  • flashinfer/jit/gdn.py
  • flashinfer/aot.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec

Applied to files:

  • flashinfer/jit/gdn.py
  • flashinfer/aot.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support

Applied to files:

  • flashinfer/jit/gdn.py
  • flashinfer/aot.py
  • flashinfer/gdn_prefill.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)

Applied to files:

  • flashinfer/jit/gdn.py
  • flashinfer/aot.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • flashinfer/__init__.py
  • flashinfer/aot.py
  • flashinfer/gdn_prefill.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed

Applied to files:

  • flashinfer/aot.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Use `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads

Applied to files:

  • csrc/gdn_prefill_launcher.cu
🧬 Code graph analysis (12)
tests/gdn/conftest.py (1)
flashinfer/utils.py (1)
  • is_sm90a_supported (531-533)
csrc/flat/cute_ext.hpp (1)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h (1)
  • layout (29-47)
csrc/flat/math_order_barrier.hpp (1)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (8)
  • void (279-321)
  • void (495-500)
  • void (503-526)
  • void (529-547)
  • void (550-568)
  • void (572-589)
  • void (592-985)
  • void (988-1215)
flashinfer/aot.py (1)
flashinfer/jit/gdn.py (1)
  • gen_gdn_prefill_sm90_module (25-37)
csrc/flat/ampere/collective/flat_collective_load.hpp (1)
csrc/flat/hopper/collective/flat_collective_load.hpp (2)
  • char (17-27)
  • to_string (17-17)
flashinfer/gdn_prefill.py (3)
flashinfer/api_logging.py (1)
  • flashinfer_api (464-565)
flashinfer/jit/gdn.py (1)
  • gen_gdn_prefill_sm90_module (25-37)
csrc/gdn_prefill_launcher.cu (2)
  • gdn_prefill (71-170)
  • gdn_prefill (71-73)
csrc/flat/hopper/collective/flat_collective_store.hpp (2)
csrc/flat/cute_ext.hpp (2)
  • alignment_for_swizzle (33-35)
  • alignment_for_swizzle (33-33)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (2)
  • args (484-486)
  • args (484-484)
csrc/flat/hopper/device/device_universal.hpp (3)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (8)
  • args (484-486)
  • args (484-484)
  • can_implement (411-428)
  • can_implement (411-411)
  • initialize_workspace (489-493)
  • initialize_workspace (489-491)
  • to_underlying_arguments (431-482)
  • to_underlying_arguments (431-432)
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (8)
  • args (174-176)
  • args (174-174)
  • args (178-182)
  • args (178-179)
  • args (184-186)
  • args (184-184)
  • args (197-203)
  • args (197-197)
csrc/flat/hopper/collective/flat_collective_store.hpp (4)
  • initialize_workspace (129-133)
  • initialize_workspace (129-131)
  • to_underlying_arguments (104-119)
  • to_underlying_arguments (104-105)
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (3)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (15)
  • args (484-486)
  • args (484-484)
  • initialize_workspace (489-493)
  • initialize_workspace (489-491)
  • params (495-495)
  • to_underlying_arguments (431-482)
  • to_underlying_arguments (431-432)
  • void (279-321)
  • void (495-500)
  • void (503-526)
  • void (529-547)
  • void (550-568)
  • void (572-589)
  • void (592-985)
  • void (988-1215)
csrc/flat/hopper/collective/flat_collective_store.hpp (8)
  • initialize_workspace (129-133)
  • initialize_workspace (129-131)
  • params (135-135)
  • to_underlying_arguments (104-119)
  • to_underlying_arguments (104-105)
  • void (135-137)
  • void (180-216)
  • void (219-242)
csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (6)
  • params (69-70)
  • params (70-70)
  • params (98-98)
  • params (98-98)
  • to_underlying_arguments (73-96)
  • to_underlying_arguments (73-76)
csrc/flat/ampere/collective/flat_collective_inverse.hpp (2)
csrc/flat/hopper/collective/flat_common.hpp (6)
  • auto (41-51)
  • auto (55-66)
  • auto (69-71)
  • auto (74-78)
  • auto (81-83)
  • convert_c_layout_to_a_layout (74-74)
csrc/flat/cute_ext.hpp (2)
  • select_tensor (20-30)
  • select_tensor (20-20)
csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (3)
csrc/flat/hopper/collective/flat_collective_store.hpp (2)
  • bool (165-177)
  • params (135-135)
csrc/flat/hopper/device/device_universal.hpp (4)
  • params (57-57)
  • params (57-57)
  • params (145-176)
  • params (145-145)
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (3)
  • params (188-190)
  • params (188-188)
  • params (205-205)
tests/gdn/test_prefill_delta_rule.py (3)
tests/gdn/reference_delta_rule.py (2)
  • exclusive_cumsum (5-9)
  • blockwise_delta_rule (332-478)
flashinfer/gdn_prefill.py (2)
  • gdn_prefill (35-58)
  • chunk_gated_delta_rule (79-200)
tests/gdn/conftest.py (1)
  • qkv_factory (50-51)
🪛 Ruff (0.14.10)
flashinfer/jit/gdn.py

35-35: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

flashinfer/gdn_prefill.py

89-89: Unused function argument: use_qk_l2norm_in_kernel

(ARG001)

benchmarks/bench_gdn_prefill.py

28-28: Unused function argument: num_k_heads

(ARG001)


31-31: Unused function argument: num_seqs

(ARG001)

tests/gdn/test_prefill_delta_rule.py

23-23: Unused function argument: block_size

(ARG001)


211-211: Unused function argument: block_size

(ARG001)

🔇 Additional comments (42)
csrc/flat/unused.hpp (1)

7-19: LGTM!

Clean placeholder type implementation. The variadic constructor and assignment operator correctly discard all inputs, making this suitable as a default template parameter for optional components.

csrc/flat/type_traits.hpp (1)

9-35: LGTM!

The type traits are well-designed. The map_to_cutlass specializations correctly map CUDA native types to their cutlass equivalents, and first_non_void properly handles the recursive case with a clear static assertion for the all-void error case.

csrc/flat/ampere/collective/flat_collective_load.hpp (1)

27-142: LGTM - Well-structured collective load implementation.

The CollectiveLoadVector class correctly implements pipelined loads with proper:

  • Pipeline acquire/commit semantics (lines 113, 132)
  • Memory fence barriers (lines 114, 131)
  • Tail handling with masking (lines 102-108)
  • Optional vector processor support (lines 123-129)
csrc/flat/ampere/collective/flat_collective_inverse.hpp (1)

59-63: Precision limitation noted - verify acceptable for use cases.

The FIXME comment indicates precision issues with half-precision. The static_assert restricts the implementation to half types only. Ensure this limitation is acceptable for the GDN attention use case, or consider adding float support if higher precision is needed.

csrc/flat/hopper/device/device_universal.hpp (2)

56-92: LGTM for workspace and occupancy query logic.

The get_workspace_size, get_grid_shape, and maximum_active_blocks implementations correctly delegate to the Kernel and handle dynamic shared memory configuration appropriately.


145-176: LGTM for the static run() entry point.

The cluster launch path for SM90+ and fallback kernel launch for older architectures is well-structured. Error handling properly checks both CUDA errors and launch status.

flashinfer/__init__.py (1)

87-87: LGTM! Public API export follows project conventions.

The new chunk_gated_delta_rule is correctly exported following the established from .module import symbol as symbol pattern. Based on learnings, this correctly makes the operation available as public API.
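The re-export pattern can be illustrated with a tiny stand-in package (the module and function bodies below are placeholders, not the real flashinfer code). The effect of `from .gdn_prefill import chunk_gated_delta_rule as chunk_gated_delta_rule` is to rebind the symbol at package level; the redundant `as` alias conventionally marks it as an intentional public re-export for type checkers and linters.

```python
import types

# Stand-in modules illustrating the import-as-export pattern; names are
# hypothetical placeholders, not the actual flashinfer modules.
inner = types.ModuleType("pkg.gdn_prefill")

def chunk_gated_delta_rule():
    # Placeholder body standing in for the real kernel-backed API.
    return "gdn"

inner.chunk_gated_delta_rule = chunk_gated_delta_rule

pkg = types.ModuleType("pkg")
# Package-level rebinding, equivalent in effect to
# `from .gdn_prefill import chunk_gated_delta_rule as chunk_gated_delta_rule`:
pkg.chunk_gated_delta_rule = inner.chunk_gated_delta_rule

print(pkg.chunk_gated_delta_rule())  # the symbol is reachable at package level
```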

flashinfer/aot.py (2)

44-44: LGTM! Import follows project conventions.

The import of gen_gdn_prefill_sm90_module is correctly placed with other JIT module imports.


537-539: LGTM! AOT registration correctly guarded by SM90 capability.

The GDN prefill module is appropriately registered under the add_misc branch with the has_sm90 check, following the established pattern for SM90-specific modules. Based on learnings, this correctly registers the operation for AOT pre-compilation support.

csrc/flat/hopper/collective/flat_named_barriers.hpp (1)

5-12: LGTM! Clean barrier ID catalog design.

The protected NumBarriersUsed = 4 allows derived classes (like DeltaRuleNamedBarriers mentioned in the summary) to safely extend barrier IDs.

Minor observation: index 3 appears unused (indices 0, 1, 2 are defined but NumBarriersUsed = 4). If intentional for future use or alignment, a brief comment could clarify this.

csrc/flat/common.hpp (1)

20-29: LGTM! CUDA_CHECK macro is correctly implemented.

The error handling correctly captures the CUDA error name, code, file, and line number.

tests/gdn/conftest.py (2)

9-13: LGTM! Proper SM90a architecture guard.

The autouse fixture correctly uses is_sm90a_supported from flashinfer.utils to skip tests on unsupported GPUs. Based on learnings, this follows the recommended pattern for GPU architecture-specific tests.


31-46: Tensors are generated on GPU via device context, not on CPU.

The gen_qkv function is called within with torch.device("cuda"): context (lines 46-48 and 237-238 in test_prefill_delta_rule.py). PyTorch's device context manager sets the default device for all tensor operations within that scope, including torch.distributions.Normal.sample() and torch.distributions.Uniform.sample() used by multidist_randu. Tensors are correctly placed on GPU through implicit device context propagation, not left on CPU. No fix is needed.

Likely an incorrect or invalid review comment.

csrc/flat/math_order_barrier.hpp (2)

1-18: LGTM! Well-structured barrier coordination template.

The OrderedNamedBarriers template provides a clean abstraction for multi-warp-group synchronization using named barriers. The conditional type alias for NBId_t correctly handles both reserved and plain barrier IDs.


36-78: Barrier synchronization logic is correct.

The ordered_or_wait and notify_next_blocked methods correctly implement an ordered barrier pattern where:

  1. Each WG waits on its assigned barrier
  2. After completing work, each WG notifies all other WGs' barriers to advance the synchronization state

The detailed comments explaining the barrier state transitions are helpful for understanding the algorithm.

csrc/flat/debug.hpp (1)

47-60: LGTM! TMA descriptor dump is useful for debugging.

The DPRINT_TMA_DESC macro provides a clean way to dump 128 bytes (32 unsigned ints) of TMA descriptor data in a formatted hex layout, which is valuable for debugging TMA-related issues.

flashinfer/gdn_prefill.py (1)

30-75: LGTM! Module caching and op registration follow best practices.

The @functools.cache decorator ensures the module is built once, and the custom op registration pattern is consistent with the codebase. Based on coding guidelines, this correctly implements module-level caching.
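A minimal sketch of the caching behavior described here (the function name and "build" body are illustrative, not the real module builder): `functools.cache` memoizes on the argument tuple, so repeated calls with the same arguments reuse the first result instead of triggering a rebuild.

```python
import functools

build_count = 0

@functools.cache
def get_gdn_module(dtype: str):
    # Stands in for an expensive JIT compilation step.
    global build_count
    build_count += 1
    return f"gdn-module[{dtype}]"

m1 = get_gdn_module("float16")
m2 = get_gdn_module("float16")   # served from the cache, no rebuild
m3 = get_gdn_module("bfloat16")  # new argument, one more build
print(build_count)  # → 2
```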

csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp (2)

27-35: Persistent scheduler support is blocked with static_assert.

The static_assert(!kIsPersistent, "not implemented") correctly guards against using unimplemented functionality. The commented-out PersistentTileScheduler line suggests this is planned for future work.

Consider removing the commented code (lines 34-35) if persistent scheduler support is not planned for the near term, or add a TODO comment linking to the tracking issue.


37-39: LGTM! Kernel type wiring is correct.

The Kernel alias correctly composes FlatKernelTmaWarpSpecializedDeltaRule with the appropriate CollectiveMainloop and TileScheduler types.

csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu (2)

23-69: Consider simplifying the dispatch logic.

The current implementation has 16 explicit branches with 2 "unreachable" cases that are indeed unreachable. This could be simplified using bit-packing or a dispatch table, but the current approach is clear and the compiler will likely optimize it.

The exhaustive enumeration ensures compile-time instantiation of all needed template combinations.
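The dispatch-table alternative mentioned above can be sketched as follows; the flag names and kernel identifiers are hypothetical, and the real launcher dispatches to template instantiations in C++ rather than strings.

```python
# Build a table keyed by the tuple of dispatch flags instead of writing
# 16 explicit if/else branches.
def make_kernel_table():
    table = {}
    for dtype in ("half", "bf16"):
        for has_alpha in (False, True):
            for has_beta in (False, True):
                for save_state in (False, True):
                    key = (dtype, has_alpha, has_beta, save_state)
                    table[key] = (
                        f"kernel_{dtype}_a{int(has_alpha)}"
                        f"_b{int(has_beta)}_s{int(save_state)}"
                    )
    return table

TABLE = make_kernel_table()

def dispatch(dtype, has_alpha, has_beta, save_state):
    # A single lookup replaces the branch cascade.
    return TABLE[(dtype, has_alpha, has_beta, save_state)]

print(dispatch("half", True, False, True))  # → kernel_half_a1_b0_s1
```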


74-86: LGTM! Explicit instantiations cover required types.

The explicit template instantiations for half and nv_bfloat16 with float state correctly match the supported dtypes in the Python API.

benchmarks/bench_gdn_prefill.py (1)

245-250: LGTM! SM90 capability check is correct.

The benchmark properly checks for SM90 support before running, providing a clear error message if the device doesn't meet requirements.

csrc/gdn_prefill_launcher.cu (3)

84-94: GQA/GVA validation looks correct.

The branching logic correctly validates the head count relationships for both GQA (num_q_heads >= num_v_heads) and GVA (num_q_heads < num_v_heads) configurations with appropriate ratio checks.
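The head-count relationship being validated can be restated in Python (an illustrative re-statement of the rule described above, not the launcher's actual C++ checks):

```python
def validate_heads(num_q_heads: int, num_k_heads: int, num_v_heads: int) -> bool:
    if num_q_heads >= num_v_heads:
        # GQA-like: groups of query heads share fewer KV heads.
        return num_k_heads == num_v_heads and num_q_heads % num_v_heads == 0
    # GVA-like: groups of value heads share fewer QK heads.
    return num_q_heads == num_k_heads and num_v_heads % num_q_heads == 0

print(validate_heads(32, 8, 8))   # GQA: 4 query heads per KV head → True
print(validate_heads(8, 8, 32))   # GVA: 4 value heads per QK head → True
print(validate_heads(32, 16, 8))  # inconsistent grouping → False
```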


153-155: Default scale calculation is correct.

The default scale of 1.0 / sqrt(head_size) is the standard attention scaling factor and is appropriately applied when the user passes scale == 0.0.


99-104: This concern is unfounded.

The CHECK_SHAPE macro calls check_shape(), which uses TVM_FFI_ICHECK_EQ to validate conditions. On failure, TVM_FFI_ICHECK_EQ throws a C++ exception that immediately unwinds the stack, preventing any code after the check from executing. Therefore, line 103 (data_ptr()) will never be reached if the shape validation fails. The code is safe from the described null pointer access risk.

Likely an incorrect or invalid review comment.

csrc/flat/hopper/collective/flat_collective_load.hpp (2)

29-41: Constructor stores references - ensure lifetime correctness.

The constructor stores references to tma_load, pipeline, and storage. This is appropriate for device-side usage within a kernel's lifetime, but callers must ensure these objects outlive the CollectiveLoadTma instance.


87-105: Pipeline state is incremented unconditionally.

The ++dst_pipe on line 103 is inside the if (lane_predicate == 1) block, meaning only one thread increments the pipeline state. This is intentional for single-producer semantics but verify that all threads in the warp maintain consistent pipeline state views if needed elsewhere.

csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (3)

380-386: Cluster/block synchronization handled correctly.

The synchronization logic properly distinguishes between cluster-level (cluster_arrive_relaxed/cluster_wait) and single-block (__syncthreads) synchronization based on ClusterShape size.


206-220: Warp role enums are well-defined.

The WarpGroupRole and LdStWarpRole enums clearly define the responsibilities of each warp/warp-group, enabling clean role-based dispatch in the kernel body.


464-467: Remove this concern—the dealloc vs alloc usage is intentional.

Lines 448 and 467 use different register allocation strategies (alloc vs dealloc) because they operate on different register magnitudes. StateMmaRegisterRequirement is computed as (total_registers - load_registers - aux_registers) / num_state_mma_warp_groups, which typically yields values >= 128. In contrast, AuxMmaRegisterRequirement is fixed at 128 - load_registers (88 or 104), which is always < 128. The kernel follows the CUTLASS pattern of using warpgroup_reg_dealloc for smaller allocations (< 128 registers) and warpgroup_reg_alloc for larger ones, as evidenced in related code (e.g., fmha_common.hpp). This asymmetry is correct and reflects different resource needs for the state versus auxiliary MMA warp groups.

csrc/flat/hopper/collective/flat_common.hpp (1)

10-29: GEMM accumulator reset logic is correct.

The gemm_reset_zero_acc function correctly handles both 2x2x1 and 3x3x3 tensor ranks, iterating over the K dimension and resetting the accumulator scale to One after each GEMM operation.

csrc/flat/hopper/collective/flat_collective_store.hpp (3)

56-58: Hardcoded assumption about element size.

Line 57 has static_assert(sizeof(SmemElementO) == 2), limiting this implementation to 16-bit types (fp16/bf16). This is documented but may need updating if other types are supported in the future.


218-242: Tensormap tail handling uses correct synchronization.

The create_tensormap_for_tail function:

  1. Copies the base tensormap (lines 226-231)
  2. Synchronizes the warp (line 232)
  3. Updates the global dimension (lines 234-238)
  4. Synchronizes again (line 239)
  5. Issues a release fence (line 241)

This follows the correct pattern for PTX tensormap manipulation.


164-177: can_process logic handles edge cases correctly.

The function correctly identifies when TMA can be used:

  • Intermediate full tiles (line 167-169)
  • Last tile that's full (line 170, first condition)
  • Last sequence where OOB can be handled by TMA (line 170, second condition)
tests/gdn/reference_delta_rule.py (3)

232-308: delta_rule function provides element-wise reference implementation.

This function implements the delta rule step-by-step for each token, serving as a clear reference for correctness verification. The implementation follows the mathematical formulation from the paper correctly.
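The token-by-token recurrence can be sketched in numpy under one common Gated DeltaNet-style formulation (this is an illustration of the math, not the repo's reference code, which handles batching, heads, and chunking): the state update is S_t = alpha_t * S_{t-1} @ (I - beta_t * k_t k_t^T) + beta_t * v_t k_t^T with output o_t = S_t @ q_t.

```python
import numpy as np

def gated_delta_rule_ref(q, k, v, alpha, beta):
    # q, k: (T, d_k); v: (T, d_v); alpha, beta: (T,)
    T, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_v, d_k))
    out = np.zeros((T, d_v))
    I = np.eye(d_k)
    for t in range(T):
        kt, vt = k[t], v[t]
        # Gated decay, delta-rule erase of the old association, then write.
        S = alpha[t] * S @ (I - beta[t] * np.outer(kt, kt)) + beta[t] * np.outer(vt, kt)
        out[t] = S @ q[t]
    return out, S

rng = np.random.default_rng(0)
T, d = 4, 8
q = rng.standard_normal((T, d))
k = rng.standard_normal((T, d))
k /= np.linalg.norm(k, axis=-1, keepdims=True)  # L2-normalized keys
v = rng.standard_normal((T, d))
alpha = rng.uniform(0.9, 1.0, T)
beta = rng.uniform(0.0, 1.0, T)
out, S_final = gated_delta_rule_ref(q, k, v, alpha, beta)
print(out.shape, S_final.shape)  # → (4, 8) (8, 8)
```

With a unit-norm key and beta = 1, the update stores the value exactly (S @ k == v), which is the defining delta-rule property.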


360-368: GQA/GVA handling with repeat_interleave is correct.

The head expansion logic correctly handles both GQA and GVA configurations using repeat_interleave to broadcast heads appropriately.
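The expansion can be sketched with numpy's `np.repeat` (the analogue of `torch.repeat_interleave` along the head axis); the shapes below are illustrative only.

```python
import numpy as np

# GQA-style expansion: each KV head is broadcast to its group of query heads.
num_q_heads, num_kv_heads, head_dim = 8, 2, 4
ratio = num_q_heads // num_kv_heads

k = np.arange(num_kv_heads * head_dim, dtype=float).reshape(num_kv_heads, head_dim)
k_expanded = np.repeat(k, ratio, axis=0)  # (num_kv_heads, d) -> (num_q_heads, d)
print(k_expanded.shape)  # → (8, 4)
```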


12-29: matmul helper handles mixed precision correctly.

The function correctly upcasts fp16/bf16 inputs to fp32 for computation, then downcasts the result. Note: bf16 inputs return fp32 output (line 25), while fp16 inputs return fp16 output (line 27). Verify this asymmetry is intentional.
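The upcast-compute-downcast pattern looks roughly like this (an illustration, not the repo's helper: the real code operates on torch fp16/bf16 tensors, while numpy's `float16` stands in here since numpy has no bfloat16):

```python
import numpy as np

def matmul_upcast(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Promote to float32 for the accumulation, then decide the output dtype.
    out = a.astype(np.float32) @ b.astype(np.float32)
    if a.dtype == np.float16 and b.dtype == np.float16:
        return out.astype(np.float16)  # fp16 in -> fp16 out
    return out                          # otherwise keep the fp32 result

a = np.ones((2, 3), dtype=np.float16)
b = np.ones((3, 2), dtype=np.float16)
print(matmul_upcast(a, b).dtype)  # → float16
```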

csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (1)

106-116: Runtime validation is available but not automatically enforced in the execution path.

The code defines strict mathematical relationships for GQA and GVA modes through tag selection (line 11-12). A can_implement() function validates these relationships at runtime (checking that head counts satisfy the GQA/GVA constraints), but this validation is not automatically called during kernel execution—callers must explicitly verify it succeeds before launching. Without this check, mismatched tag selection could allow the divisions at lines 109 and 113 to compute incorrect results or encounter division by zero.

Ensure any public API or usage examples require can_implement() validation before kernel launch.

csrc/flat/hopper/kernel/flat_options.hpp (1)

9-13: Clean compile-time options system design.

The Option struct with tag-based lookup and the Tag enum provide a flexible, type-safe mechanism for kernel configuration. The use of template metaprogramming for find_option_t and add_option_t is idiomatic for CUDA/CUTLASS codebases.

Also applies to: 62-77

csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (3)

274-324: AlphaProcessor implementation looks correct.

The warp-level prefix sum using shuffle operations (__shfl_up_sync, __shfl_sync) for computing cumulative log/product of alpha values is well-structured. The use of log2f/exp2f pairs and the epsilon (1e-10f) for numerical stability in Line 293 are appropriate for this computation.
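The numerical idea behind the log2/exp2 pair can be shown serially in numpy (the kernel computes the same prefix product in parallel with warp shuffles; this sketch only illustrates the log-space numerics, with the same epsilon guarding log of zero):

```python
import numpy as np

eps = 1e-10
alpha = np.full(64, 0.9, dtype=np.float32)  # per-token decay gates

# Direct running product vs. the log-space formulation: sum log2(alpha + eps),
# then exponentiate. Both give the cumulative product alpha_1 * ... * alpha_t.
direct = np.cumprod(alpha)
via_log = np.exp2(np.cumsum(np.log2(alpha + eps)))

print(np.max(np.abs(direct - via_log)) < 1e-4)  # → True
```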


591-984: Complex compute function with solid warp-specialized structure.

The compute function handles the delta rule main loop with proper pipeline synchronization, ordered math barriers for warpgroup coordination, and clear separation between Q@KV, QK@V, and KV update phases. The use of CUTE_NO_UNROLL for variable-iteration loops and INLINE_LAMBDA for hot-path lambdas is appropriate.



Comment on lines +164 to +174
template <typename TensorT>
CUTE_DEVICE void blockwise_diagonal_inversed_4x4_to_8x8(TensorT&& mat) {
constexpr auto L =
typename std::remove_const_t<std::remove_reference_t<TensorT>>::layout_type{};
static_assert(rank(L) == 2);
static_assert(size<0>(L) == 8);
static_assert(size<1>(L) == 8);
auto mat_NxN_2x2 = flat_divide(std::forward<TensorT>(mat), Shape<_4, _4>{});

// FIXME: implement
}

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

rg -n "blockwise_diagonal_inversed_4x4_to_8x8" --type cpp

Repository: flashinfer-ai/flashinfer

Length of output: 202


Remove unimplemented dead code.

blockwise_diagonal_inversed_4x4_to_8x8 is not called anywhere in the codebase and contains only a FIXME placeholder. Remove this function to avoid confusion.

🤖 Prompt for AI Agents
csrc/flat/ampere/collective/flat_collective_inverse.hpp around lines 164-174:
remove the unused, unimplemented function blockwise_diagonal_inversed_4x4_to_8x8
entirely; delete its template declaration, body and related local variable so
there is no dead code or FIXME left in the file, and run a quick grep to confirm
it’s not referenced elsewhere before committing.

Comment on lines +370 to +371
using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;

⚠️ Potential issue | 🟡 Minor

Duplicate type alias declaration.

CopyOpO_S2R is defined twice on consecutive lines. This appears to be a copy-paste error. The second definition shadows the first, though both are identical so it compiles without error.

🔎 Proposed fix
     using CopyOpD_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
     using CopyOpC_S2R = std::conditional_t<is_col_major, SM75_U32x4_LDSM_N, SM75_U16x8_LDSM_T>;
     using CopyOpA_S2R = std::conditional_t<is_col_major, SM75_U32x2_LDSM_N, SM75_U16x4_LDSM_T>;
     using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
-    using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
-    using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
+    using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
🤖 Prompt for AI Agents
In csrc/flat/ampere/collective/flat_collective_inverse.hpp around lines 370-371
there is a duplicate type alias declaration for CopyOpO_S2R on consecutive
lines; remove the redundant second declaration (or the duplicate line) so the
alias is declared only once, keeping the existing conditional_t definition
intact and preserving surrounding formatting.

Comment on lines +19 to +30
template <int... Is, typename Tensor>
__forceinline__ __host__ __device__ constexpr auto select_tensor(Tensor&& t) {
if constexpr (is_composed_layout<decltype(t.layout())>::value) {
return make_tensor(
std::forward<Tensor>(t).data(),
make_composed_layout(std::forward<Tensor>(t).layout().layout_a(),
std::forward<Tensor>(t).layout().offset(),
select<Is...>(std::forward<Tensor>(t).layout().layout_b())));
} else {
return make_tensor(std::forward<Tensor>(t).data(), select<Is...>(t.layout()));
}
}

⚠️ Potential issue | 🟠 Major

Multiple std::forward calls on the same parameter may cause undefined behavior.

In the composed layout branch (lines 22-26), std::forward<Tensor>(t) is called multiple times on the same parameter. If t is an rvalue, the first forward may move from it, leaving subsequent accesses in an undefined state.

🔎 Proposed fix
 template <int... Is, typename Tensor>
 __forceinline__ __host__ __device__ constexpr auto select_tensor(Tensor&& t) {
   if constexpr (is_composed_layout<decltype(t.layout())>::value) {
+    auto& layout = t.layout();
     return make_tensor(
         std::forward<Tensor>(t).data(),
-        make_composed_layout(std::forward<Tensor>(t).layout().layout_a(),
-                             std::forward<Tensor>(t).layout().offset(),
-                             select<Is...>(std::forward<Tensor>(t).layout().layout_b())));
+        make_composed_layout(layout.layout_a(),
+                             layout.offset(),
+                             select<Is...>(layout.layout_b())));
   } else {
     return make_tensor(std::forward<Tensor>(t).data(), select<Is...>(t.layout()));
   }
 }
🤖 Prompt for AI Agents
In csrc/flat/cute_ext.hpp around lines 19 to 30, the composed-layout branch
calls std::forward<Tensor>(t) multiple times which can move from an rvalue more
than once and lead to undefined behavior; fix this by forwarding t exactly once
into a local variable (e.g., auto&& tmp = std::forward<Tensor>(t)) at the top of
the function and then use that local tmp for data(), layout(), layout_a(),
offset(), and layout_b(), ensuring all subsequent accesses use the single
forwarded object.

Comment on lines +410 to +428
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
auto ratio = problem_size.num_q_heads > problem_size.num_v_heads
? problem_size.num_q_heads / problem_size.num_v_heads
: problem_size.num_v_heads / problem_size.num_q_heads;

constexpr bool IsGVAEnabled = find_option_t<Tag::kIsGVA, false_type, Options>::value;

bool is_gqa_like = (problem_size.num_k_heads == problem_size.num_v_heads) &&
(problem_size.num_q_heads == ratio * problem_size.num_k_heads) &&
(problem_size.num_q_heads == ratio * problem_size.num_v_heads);

bool is_gva_like = (problem_size.num_q_heads == problem_size.num_k_heads) &&
(problem_size.num_v_heads == ratio * problem_size.num_q_heads) &&
(problem_size.num_v_heads == ratio * problem_size.num_k_heads);
return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
(problem_size.head_size <= get<2>(TileShape{})) &&
((problem_size.head_size % Alignment) == 0);
}

⚠️ Potential issue | 🟡 Minor

Potential division by zero in ratio calculation.

If problem_size.num_v_heads is zero (or num_q_heads when it's smaller), the ratio calculation will produce undefined behavior. Although the caller likely validates head counts, adding a guard or assertion would make this safer.

🔎 Proposed fix
   template <class ProblemShape>
   static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
+    if (problem_size.num_q_heads == 0 || problem_size.num_v_heads == 0) {
+      return false;
+    }
     auto ratio = problem_size.num_q_heads > problem_size.num_v_heads
                      ? problem_size.num_q_heads / problem_size.num_v_heads
                      : problem_size.num_v_heads / problem_size.num_q_heads;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
auto ratio = problem_size.num_q_heads > problem_size.num_v_heads
? problem_size.num_q_heads / problem_size.num_v_heads
: problem_size.num_v_heads / problem_size.num_q_heads;
constexpr bool IsGVAEnabled = find_option_t<Tag::kIsGVA, false_type, Options>::value;
bool is_gqa_like = (problem_size.num_k_heads == problem_size.num_v_heads) &&
(problem_size.num_q_heads == ratio * problem_size.num_k_heads) &&
(problem_size.num_q_heads == ratio * problem_size.num_v_heads);
bool is_gva_like = (problem_size.num_q_heads == problem_size.num_k_heads) &&
(problem_size.num_v_heads == ratio * problem_size.num_q_heads) &&
(problem_size.num_v_heads == ratio * problem_size.num_k_heads);
return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
(problem_size.head_size <= get<2>(TileShape{})) &&
((problem_size.head_size % Alignment) == 0);
}
template <class ProblemShape>
static bool can_implement(ProblemShape const& problem_size, Arguments const& args) {
if (problem_size.num_q_heads == 0 || problem_size.num_v_heads == 0) {
return false;
}
auto ratio = problem_size.num_q_heads > problem_size.num_v_heads
? problem_size.num_q_heads / problem_size.num_v_heads
: problem_size.num_v_heads / problem_size.num_q_heads;
constexpr bool IsGVAEnabled = find_option_t<Tag::kIsGVA, false_type, Options>::value;
bool is_gqa_like = (problem_size.num_k_heads == problem_size.num_v_heads) &&
(problem_size.num_q_heads == ratio * problem_size.num_k_heads) &&
(problem_size.num_q_heads == ratio * problem_size.num_v_heads);
bool is_gva_like = (problem_size.num_q_heads == problem_size.num_k_heads) &&
(problem_size.num_v_heads == ratio * problem_size.num_q_heads) &&
(problem_size.num_v_heads == ratio * problem_size.num_k_heads);
return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) &&
(problem_size.head_size <= get<2>(TileShape{})) &&
((problem_size.head_size % Alignment) == 0);
}
🤖 Prompt for AI Agents
In
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
around lines 410-428, the ratio calculation can divide by zero when num_v_heads
or num_q_heads is zero; add a guard before computing ratio to return false (or
assert) if either problem_size.num_q_heads == 0 or problem_size.num_v_heads ==
0, then compute ratio safely (integer division is fine here) and proceed; ensure
the early-return/validation is placed before any use of ratio so the function
never performs a division by zero.

Comment on lines +30 to +34
bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}

⚠️ Potential issue | 🔴 Critical

Critical bug: static local variable causes shared state across all instances.

The is_initialized function uses a static bool, meaning all instances of Universal<Kernel_> share the same initialization flag. If you create two Universal objects and initialize one, the second will incorrectly believe it's already initialized and skip cudaFuncSetAttribute.
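The failure mode is easy to reproduce in miniature. Below is a Python analog (the `Universal` toy class here is a hypothetical stand-in, not the real device type) where one class-level flag is shared across instances, just as the C++ static local is:

```python
class Universal:
    # Analog of the C++ `static bool initialized`: a single class-level
    # flag shared by every instance instead of per-instance state.
    _initialized = False

    def is_initialized(self, set_flag: bool = False) -> bool:
        if set_flag:
            Universal._initialized = True
        return Universal._initialized


a, b = Universal(), Universal()
a.is_initialized(set_flag=True)   # initialize only `a`
print(b.is_initialized())         # True -- `b` wrongly reports initialized
```

With a per-instance member (`self._initialized`), `b.is_initialized()` would correctly return `False`.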

🔎 Proposed fix: use instance member instead of static local
  private:
   /// Kernel API parameters object
   Params params_;
+  bool initialized_ = false;

-  bool is_initialized(bool set = false) {
-    static bool initialized = false;
-    if (set) initialized = true;
-    return initialized;
-  }
+  bool is_initialized() const { return initialized_; }
+  void set_initialized() { initialized_ = true; }

Then update usages in initialize():

-    if (is_initialized()) return Status::kSuccess;
+    if (is_initialized()) return Status::kSuccess;
     // ... smem setup code ...
-    is_initialized(true);
+    set_initialized();

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In csrc/flat/hopper/device/device_universal.hpp around lines 30-34, the
is_initialized(bool set=false) function uses a static local variable causing
shared initialization state across all Universal<Kernel_> instances; replace the
static local with a non-static instance member (e.g., bool initialized_
initialized to false in the class) and update initialize() and any call sites to
use this member (use initialized_ = true when setting and return initialized_
when querying) so each object tracks its own initialization state.

Comment on lines +104 to +107
.ptr_Q = (T*)q, .dQ = {q_tok_stride, _1{}, q_head_stride},
.ptr_K = (T*)k, .dK = {k_tok_stride, _1{}, k_head_stride},
.ptr_V = (T*)v, .dV = {v_tok_stride, _1{}, v_head_stride},
.ptr_O = (T*)output, .dO = {o_tok_stride, _1{}, o_head_stride},

⚠️ Potential issue | 🟠 Major

Casting away const from input pointers is unsafe.

The input pointers q, k, v are declared as TQKV const*, but the casts (T*)q, (T*)k, (T*)v strip the const qualifier. If the kernel unexpectedly writes to these buffers, this results in undefined behavior. If the API genuinely requires non-const pointers but guarantees read-only access, consider documenting this or using const_cast explicitly to make the intent clear.

🔎 Suggested improvement

If the underlying API truly requires mutable pointers but won't modify them:

-                .ptr_Q = (T*)q,      .dQ = {q_tok_stride, _1{}, q_head_stride},
-                .ptr_K = (T*)k,      .dK = {k_tok_stride, _1{}, k_head_stride},
-                .ptr_V = (T*)v,      .dV = {v_tok_stride, _1{}, v_head_stride},
+                // SAFETY: cutlass API requires non-const but these are read-only
+                .ptr_Q = const_cast<T*>(reinterpret_cast<T const*>(q)),
+                .ptr_K = const_cast<T*>(reinterpret_cast<T const*>(k)),
+                .ptr_V = const_cast<T*>(reinterpret_cast<T const*>(v)),

Better yet, if possible, update the Arguments type to accept T const* for input tensors.

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +117 to +127
@pytest.mark.parametrize("beta", [False, True])
@pytest.mark.parametrize("alpha", [False, True])
@pytest.mark.parametrize("scale", [1.0, "auto"])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize(
"num_q_heads, num_k_heads, num_v_heads",
[(1, 1, 1), (4, 1, 1), (3, 3, 3), (6, 2, 2), (1, 1, 2), (2, 2, 4)],
)
@pytest.mark.parametrize("seq_lens", [[64], [128], [256], [256, 256], [64, 128, 512]])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@coderabbitai coderabbitai bot Jan 1, 2026

⚠️ Potential issue | 🟡 Minor

Missing SM90 architecture check per coding guidelines.

The coding guidelines specify using flashinfer.utils functions like is_sm90a_supported to skip tests on unsupported GPU architectures. This test requires SM90 but doesn't check for it.

Suggested fix

Add at the top of the test functions or as a module-level skip:

from flashinfer.utils import is_sm90a_supported

# At module level or in each test:
@pytest.mark.skipif(not is_sm90a_supported(), reason="GDN requires SM90a")

As per coding guidelines, test implementations should use flashinfer.utils functions to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents
In tests/gdn/test_prefill_delta_rule.py around lines 117 to 127, the
parametrized test lacks a check to skip execution on GPUs without SM90 support;
add an import from flashinfer.utils (is_sm90a_supported) and apply a
pytest.mark.skipif(not is_sm90a_supported(), reason="GDN requires SM90a") either
at the module top for all tests or directly on the affected test function(s) so
the test is skipped on unsupported architectures.

Collaborator

It was done in tests/gdn/conftest.py

Contributor

@yzh119, you're absolutely right! I apologize for the incorrect comment.

I can see from tests/gdn/conftest.py that there is an autouse fixture skip_if_not_sm90a that automatically applies the SM90a check to all tests in the tests/gdn/ directory using is_sm90a_supported() from flashinfer.utils. This is actually a better approach than adding the decorator to each individual test function, as it:

  1. ✅ Uses the correct flashinfer.utils.is_sm90a_supported() function per the coding guidelines
  2. ✅ Applies the skip automatically to all GDN tests without repetition
  3. ✅ Keeps the test code cleaner and more maintainable

The implementation correctly follows the architecture check guidelines. Thank you for pointing this out!


🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
csrc/flat/prefill/prefill_kernel.hpp (1)

14-22: Add documentation for this public API function.

Consider adding function-level documentation (e.g., Doxygen-style comments) that describes:

  • The purpose of the delta rule prefill kernel
  • Template parameter requirements and constraints
  • Parameter descriptions (especially the distinction between num_v_heads and num_o_heads)
  • Expected behavior and preconditions
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 16e5331 and 5052358.

📒 Files selected for processing (1)
  • csrc/flat/prefill/prefill_kernel.hpp
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
csrc/flat/prefill/prefill_kernel.hpp (3)

1-6: LGTM!

Standard header guard and includes are appropriate for this kernel launcher declaration.


7-10: LGTM!

The forward declaration is a good practice to minimize compilation dependencies while allowing the launcher signature to reference the SM90 architecture tag.


20-21: These parameters serve different purposes in delta rule grouped attention and are intentional, not a copy-paste error.

The num_v_heads parameter dimensions the value computation and controls attention grouping logic (GQA/GVA patterns), while num_o_heads dimensions the output and state tensor storage. In grouped attention scenarios, these can legitimately differ, making both parameters necessary. The codebase includes explicit documentation linking num_o_heads to num_sab_heads (state/alpha/beta), confirming the intentional design.


namespace flat {

template <typename ArchTag, // TODO: hide this

⚠️ Potential issue | 🟠 Major

Resolve or remove the TODO comment before merging.

The TODO comment indicates the ArchTag template parameter should be hidden from the public API, suggesting the API design is not finalized. Consider one of the following approaches:

  • Move ArchTag determination inside the function using type traits or runtime detection
  • Use a non-template wrapper function that internally dispatches to the appropriate architecture
  • If the current design is final, remove the TODO
🤖 Prompt for AI Agents
In csrc/flat/prefill/prefill_kernel.hpp around line 14, remove or resolve the
TODO about hiding the ArchTag template parameter: either (A) move architecture
selection inside the function by deriving ArchTag with type traits or runtime
detection and make the function non-templated on ArchTag, (B) provide a
non-template public wrapper that dispatches to template instantiations of the
implementation based on detected/derived architecture, or (C) if the template
parameter is intended to remain part of the public API, delete the TODO and add
a clarifying comment documenting the design decision; implement one of these
options and update comments/tests accordingly.

@yzh119
Collaborator

yzh119 commented Jan 1, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !222 has been created, and the CI pipeline #41026587 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #41026587: 9/20 passed

@yzh119
Collaborator

yzh119 commented Jan 2, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !222 has been updated with latest changes, and the CI pipeline #41058613 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
flashinfer/gdn_prefill.py (1)

91-91: [Duplicate] The use_qk_l2norm_in_kernel parameter is unused.

This parameter is defined in the function signature and documented, but it is not passed to the underlying gdn_prefill kernel call at line 186. Tests and benchmarks perform L2 normalization on k in Python before calling this function, confirming this parameter is not implemented.
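Since the tests do this normalization in Python, a dependency-free sketch of that preprocessing step may help readers (`l2_normalize` is illustrative, not a library helper; the tests themselves use torch):

```python
import math

def l2_normalize(vectors, eps=1e-6):
    """Scale each row to (approximately) unit L2 norm, as done on q/k
    before calling the prefill kernel."""
    out = []
    for v in vectors:
        norm = math.sqrt(sum(x * x for x in v)) + eps
        out.append([x / norm for x in v])
    return out

k = [[3.0, 4.0], [0.0, 2.0]]
k_normed = l2_normalize(k)
print(k_normed[0])  # approximately [0.6, 0.8]
```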

🧹 Nitpick comments (1)
tests/gdn/test_prefill_delta_rule.py (1)

23-23: Unused block_size parameter in test helpers.

The block_size parameter is accepted by both _test_prefill_kernel (line 23) and _test_chunked_prefill (line 211) but is never used within the function bodies. This parameter is included in the test parametrization but serves no purpose in the current implementation.

🔎 Suggested fix

If block_size is not needed for validation or other purposes, remove it from the function signatures and test parametrization:

 def _test_prefill_kernel(
     qkv_factory,
     dtype: str,
     num_q_heads: int,
     num_k_heads: int,
     num_v_heads: int,
     head_size: int,
-    block_size: int,
     seq_lens: list[int],
     scale: float,
     alpha: bool,
     beta: bool,
     seed: int | None = None,
 ):

And remove from parametrization:

-@pytest.mark.parametrize("block_size", [64])

Apply the same changes to _test_chunked_prefill and test_chunked_prefill.

Also applies to: 211-211

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5052358 and a75b943.

📒 Files selected for processing (2)
  • flashinfer/gdn_prefill.py
  • tests/gdn/test_prefill_delta_rule.py
🧰 Additional context used
📓 Path-based instructions (2)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/gdn_prefill.py
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/gdn/test_prefill_delta_rule.py
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support

Applied to files:

  • flashinfer/gdn_prefill.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • flashinfer/gdn_prefill.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

Applied to files:

  • tests/gdn/test_prefill_delta_rule.py
🧬 Code graph analysis (1)
tests/gdn/test_prefill_delta_rule.py (3)
tests/gdn/reference_delta_rule.py (2)
  • exclusive_cumsum (5-9)
  • blockwise_delta_rule (332-478)
flashinfer/gdn_prefill.py (2)
  • gdn_prefill (37-60)
  • chunk_gated_delta_rule (81-202)
tests/gdn/conftest.py (1)
  • qkv_factory (50-51)
🪛 Ruff (0.14.10)
flashinfer/gdn_prefill.py

91-91: Unused function argument: use_qk_l2norm_in_kernel

(ARG001)

tests/gdn/test_prefill_delta_rule.py

23-23: Unused function argument: block_size

(ARG001)


211-211: Unused function argument: block_size

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment on lines +118 to +120
scale (Optional[float]):
Scale factor for the attention scores.
If not provided, defaults to ``1 / sqrt(head_size)``. Default: ``None``.

⚠️ Potential issue | 🟡 Minor

Inconsistency between docstring and implementation for scale default.

The docstring (lines 118-120) states that if scale is not provided, it "defaults to 1 / sqrt(head_size)", but the implementation (line 196) passes 0.0 to the kernel when scale is None.

If the kernel interprets 0.0 as a sentinel value to apply the default scaling internally, this behavior should be documented in the docstring. Otherwise, either compute the default scale here or update the docstring to reflect the actual behavior.
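A minimal sketch of computing the default locally, assuming the kernel treats 0.0 as a sentinel (`resolve_scale` is illustrative, not part of the API):

```python
import math

def resolve_scale(scale, head_size):
    # Treat None (and the kernel's 0.0 sentinel) as "use 1/sqrt(head_size)".
    if scale is None or scale == 0.0:
        return 1.0 / math.sqrt(head_size)
    return scale

print(resolve_scale(None, 128))  # ~0.0884
print(resolve_scale(0.5, 128))   # 0.5
```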

🔎 Suggested clarification

Option 1: If kernel handles the default, update the docstring:

         scale (Optional[float]):
-            Scale factor for the attention scores.
-            If not provided, defaults to ``1 / sqrt(head_size)``. Default: ``None``.
+            Scale factor for the attention scores.
+            If None or 0.0, the kernel applies the default scale of ``1 / sqrt(head_size)``. Default: ``None``.

Option 2: If the default should be computed here:

-        scale if scale is not None else 0.0,
+        scale if scale is not None else (1.0 / (head_size ** 0.5)),

Also applies to: 196-196

🤖 Prompt for AI Agents
In flashinfer/gdn_prefill.py around lines 118-120 (docstring) and line 196 (call
site), the docstring claims scale defaults to 1/sqrt(head_size) but the code
passes 0.0 when scale is None; make them consistent by either computing the
default locally or documenting the 0.0 sentinel. Fix by one of two concise
changes: (A) if you want the module to compute the default, replace the None
branch so that when scale is None you compute scale = 1.0 / math.sqrt(head_size)
and pass that to the kernel (ensure head_size is available), or (B) if the
kernel internally treats 0.0 as "use default", update the docstring to state
that passing 0.0 signals the kernel to use 1/sqrt(head_size) internally; apply
the chosen change at line 196 and update the docstring accordingly.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #41058613: 13/20 passed

@yzh119 yzh119 left a comment

LGTM

@yzh119 yzh119 enabled auto-merge (squash) January 3, 2026 07:50
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 8

♻️ Duplicate comments (12)
csrc/flat/hopper/device/device_universal.hpp (2)

45-49: Critical bug: static local variable causes shared state across all instances.

This issue was already flagged in a previous review. The is_initialized function uses a static bool, meaning all instances of Universal<Kernel_> share the same initialization flag.


158-160: Typo in comment: "to_underling_arguments" should be "to_underlying_arguments".

This issue was already flagged in a previous review.

csrc/flat/ampere/collective/flat_collective_inverse.hpp (2)

179-189: Remove unimplemented dead code.

blockwise_diagonal_inversed_4x4_to_8x8 contains only a FIXME placeholder and is not called anywhere in the codebase. Remove this function to avoid confusion.

🤖 Prompt for AI Agents
csrc/flat/ampere/collective/flat_collective_inverse.hpp around lines 179-189:
remove the unused, unimplemented function blockwise_diagonal_inversed_4x4_to_8x8
entirely; delete its template declaration, body and related local variable so
there is no dead code or FIXME left in the file.

385-386: Duplicate type alias declaration.

CopyOpO_S2R is defined twice on consecutive lines, which appears to be a copy-paste error. Redeclaring an identical type alias in the same scope is legal, so this compiles without error, but the duplicate line should be removed.

🔎 Proposed fix
     using CopyOpD_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
     using CopyOpC_S2R = std::conditional_t<is_col_major, SM75_U32x4_LDSM_N, SM75_U16x8_LDSM_T>;
     using CopyOpA_S2R = std::conditional_t<is_col_major, SM75_U32x2_LDSM_N, SM75_U16x4_LDSM_T>;
     using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
-    using CopyOpO_S2R = std::conditional_t<is_col_major, SM75_U16x8_LDSM_T, SM75_U32x4_LDSM_N>;
csrc/flat/common.hpp (1)

26-33: Critical: Fix CHECK macro syntax errors.

Line 30 contains two syntax errors that will cause compilation failure:

  1. ##expr should be #expr for proper stringification
  2. Missing comma between msg and __FILE__
🔎 Proposed fix
 #define CHECK(expr, msg)                                                                           \
   do {                                                                                             \
     if (!(expr)) {                                                                                 \
       std::string buffer(1024, '\0');                                                              \
-      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \
+      sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", #expr, msg, __FILE__, __LINE__); \
       throw std::runtime_error(buffer.c_str());                                                    \
     }                                                                                              \
   } while (0)
flashinfer/gdn_prefill.py (2)

118-120: Clarify the scale default mechanism in the docstring.

The docstring states the default is 1/sqrt(head_size), but the implementation passes 0.0 to the kernel (line 196), which the kernel interprets as a sentinel to apply the default. Consider clarifying this in the docstring to avoid confusion.

🔎 Suggested clarification
         scale (Optional[float]):
-            Scale factor for the attention scores.
-            If not provided, defaults to ``1 / sqrt(head_size)``. Default: ``None``.
+            Scale factor for the attention scores.
+            If None (or 0.0), the kernel applies the default scale of ``1 / sqrt(head_size)``. Default: ``None``.

91-91: Unused parameter use_qk_l2norm_in_kernel.

This parameter is defined in the signature and documented, but it's never used in the function body or passed to the kernel. Consider either implementing the functionality or removing the parameter to avoid user confusion.

csrc/flat/prefill/prefill_kernel.hpp (1)

29-37: TODO: Consider tracking the API design decision.

The TODO comment indicates ArchTag should be hidden from the public API. If this is intentional for now, consider creating an issue to track this future improvement, or document why the current design was chosen.

csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh (1)

119-124: Casting away const from input pointers.

The casts (T*)q, (T*)k, (T*)v strip the const qualifier from input pointers. If the underlying CUTLASS API requires mutable pointers but guarantees read-only access, consider adding a comment documenting this safety invariant.

csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (2)

117-118: Stray semicolon on line 118.

There's an extraneous semicolon after the seq_idx declaration.

🔎 Proposed fix
     int32_t seq_idx;
-    ;
     int32_t q_head_idx;

133-145: Out-of-bounds access if seq_idx is invalid.

Lines 133-134 access problem_size.cu_seqlens[seq_idx] and problem_size.cu_seqlens[seq_idx + 1] before validating whether the block should be scheduled. If blockIdx.x results in an invalid seq_idx, this causes undefined behavior. Consider moving the scheduled check before accessing cu_seqlens.
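The reordering amounts to validating the tile before any cu_seqlens access. A small Python model of that scheduler logic (names hypothetical, mirroring the shape of the tile-scheduler code):

```python
def get_work(block_idx, cu_seqlens, num_heads):
    # Validate the block *before* touching cu_seqlens so an out-of-range
    # seq_idx never indexes the array.
    seq_idx = block_idx // num_heads
    num_seqs = len(cu_seqlens) - 1
    if seq_idx >= num_seqs:        # the "scheduled" check, moved first
        return None                # no valid work for this block
    return seq_idx, cu_seqlens[seq_idx], cu_seqlens[seq_idx + 1]

cu_seqlens = [0, 64, 192]
print(get_work(3, cu_seqlens, num_heads=2))  # (1, 64, 192)
print(get_work(5, cu_seqlens, num_heads=2))  # None
```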

benchmarks/bench_gdn_prefill.py (1)

199-201: Dead code: unused calculation result.

This calculation is not assigned to any variable or used. Remove the dead code.

🔎 Suggested fix
     # Get device info for bandwidth calculation
     props = torch.cuda.get_device_properties(0)
-    props.total_memory * 2 / 1e12  # Approximate peak bandwidth
🧹 Nitpick comments (18)
csrc/flat/hopper/device/device_universal.hpp (2)

75-107: Consider extracting duplicated cudaFuncSetAttribute logic.

The dynamic shared memory configuration code (lines 82-92 and 128-138) is duplicated between maximum_active_blocks() and initialize(). Both blocks check smem_size >= (48 << 10) and call cudaFuncSetAttribute with identical error handling.

🔎 Proposed refactor: extract into a helper method
+ private:
+  static Status configure_smem_size() {
+    int smem_size = Kernel::SharedStorageSize;
+    if (smem_size >= (48 << 10)) {
+      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
+      cudaError_t result = cudaFuncSetAttribute(
+          device_kernel<Kernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+      if (cudaSuccess != result) {
+        result = cudaGetLastError();  // to clear the error bit
+        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
+        return Status::kErrorInternal;
+      }
+    }
+    return Status::kSuccess;
+  }

Then use configure_smem_size() in both locations.

Also applies to: 124-140


160-191: Consider making params const in static run() method.

The params parameter is passed as a non-const reference but appears to be read-only within the method. Making it const Params& would better express intent and prevent accidental modifications.

🔎 Proposed change
-  static Status run(Params& params, cudaStream_t stream = nullptr) {
+  static Status run(Params const& params, cudaStream_t stream = nullptr) {
csrc/flat/ampere/collective/flat_collective_inverse.hpp (1)

74-82: Clarify the precision concern.

The FIXME at Line 76 indicates "precision is not good due to half" but provides no guidance on acceptable error bounds or whether this is a known limitation. For production code, either:

  • Document the expected precision characteristics and acceptable use cases, or
  • Track this as a TODO for future FP32 support with a reference to an issue
csrc/flat/hopper/collective/flat_common.hpp (2)

25-50: Document the accumulator reset behavior.

The functions gemm_reset_zero_acc and gemm_zero_acc have subtle side effects on atom.accumulate_:

  • gemm_zero_acc zeros the accumulator initially, then accumulates (adds) on subsequent k_blocks
  • gemm_reset_zero_acc assumes the accumulator is already initialized and switches to accumulation after the first k_block

The relationship between these functions and their naming is not immediately clear. Consider:

  1. Adding function-level documentation explaining when to use each
  2. Clarifying that atom.accumulate_ is modified as a side effect
  3. Documenting that the caller's atom state will be changed to ScaleOut::One after these calls
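A schematic of the intended contract (a pure-Python stand-in with floats in place of tile MMAs; not the CUTLASS API):

```python
def gemm_zero_acc(acc, k_block_partials):
    # First k-block overwrites the accumulator (ScaleOut::Zero); later
    # k-blocks add into it (ScaleOut::One). Mirrors the side effect on
    # atom.accumulate_: the "atom" is left in accumulate mode afterwards.
    accumulate = False
    for partial in k_block_partials:
        acc = acc + partial if accumulate else partial
        accumulate = True
    return acc, accumulate

result, mode_after = gemm_zero_acc(acc=999.0, k_block_partials=[1.0, 2.0, 3.0])
print(result)      # 6.0 -- stale accumulator contents are discarded
print(mode_after)  # True -- caller's "atom" is now in accumulate mode
```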

23-23: Consider avoiding using namespace in header files.

While convenient, using namespace cute; in a header file brings all cute symbols into any translation unit that includes this header, which can lead to name collisions and reduced readability.

Consider using explicit cute:: qualifications or targeted using-declarations for specific symbols if the project style permits.

flashinfer/jit/gdn.py (1)

35-35: Consider iterable unpacking for better style.

While the current list concatenation works correctly, iterable unpacking is more idiomatic Python.

🔎 Proposed refactor
-        extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
+        extra_cuda_cflags=[*sm90a_nvcc_flags, "-DFLAT_SM90A_ENABLED", "-std=c++20"],
csrc/flat/math_order_barrier.hpp (1)

47-49: Address persistent scheduler compatibility.

The FIXME comment indicates a known issue with persistent schedulers. The destructor is currently empty, which may cause resource leaks or incorrect barrier states when kernels are reused.

Would you like me to help design a solution for persistent scheduler compatibility, or should this be tracked in a separate issue?

csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (1)

121-131: Consider adding assertions for head ratio invariants.

The code assumes num_q_heads >= num_v_heads for GQA and num_v_heads >= num_q_heads for GVA. If these invariants are violated, the divisions on lines 124 and 128 may produce incorrect results (e.g., zero divisor from integer division). Consider adding debug assertions to catch API misuse early.

🔎 Suggested defensive check
     if constexpr (std::is_same_v<GroupingTag, GQATag>) {
+      assert(params.num_q_heads >= params.num_v_heads && "GQA requires num_q_heads >= num_v_heads");
       seq_idx = blockIdx.x / params.num_q_heads;
       q_head_idx = blockIdx.x % params.num_q_heads;
       v_head_idx = q_head_idx / (params.num_q_heads / params.num_v_heads);
     } else if constexpr (std::is_same_v<GroupingTag, GVATag>) {
+      assert(params.num_v_heads >= params.num_q_heads && "GVA requires num_v_heads >= num_q_heads");
       seq_idx = blockIdx.x / params.num_v_heads;
       v_head_idx = blockIdx.x % params.num_v_heads;
       q_head_idx = v_head_idx / (params.num_v_heads / params.num_q_heads);
csrc/flat/hopper/collective/flat_collective_store.hpp (1)

72-73: Consider documenting the 16-bit type constraint.

The static_assert(sizeof(SmemElementO) == 2) enforces that only 16-bit types (e.g., half, bfloat16) are supported. A brief comment explaining this constraint would help future maintainers understand the design choice.

+  // SW32 swizzle layout requires 16-bit elements (half/bfloat16)
   static_assert(sizeof(SmemElementO) == 2);
csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh (1)

134-148: Consider including status codes in error messages.

The error messages are generic (e.g., "can_implement failed"). Including the cutlass::Status code would help debugging.

🔎 Suggested improvement
     status = op.can_implement(arguments);
     if (status != cutlass::Status::kSuccess) {
-      throw std::runtime_error("can_implement failed");
+      throw std::runtime_error("can_implement failed with status: " + 
+                               std::to_string(static_cast<int>(status)));
     }
csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp (1)

42-54: LGTM with a minor cleanup suggestion.

The kernel builder specialization is well-structured. The static_assert for persistent mode clearly indicates work-in-progress.

Consider removing the commented-out code (lines 49-50) as it adds noise. If this is intentional for future reference, a TODO comment would be clearer.

🔎 Suggested cleanup
   using GroupingTag = std::conditional_t<kIsGVA, GVATag, GQATag>;
   using TileScheduler = flat::kernel::IndividualTileScheduler<GroupingTag>;
-  // using TileScheduler = std::conditional_t<kIsPersistent, flat::kernel::PersistentTileScheduler,
-  // flat::kernel::IndividualTileScheduler>;
 
   using Kernel = flat::kernel::FlatKernelTmaWarpSpecializedDeltaRule<CollectiveMainloop,
                                                                      TileScheduler, Options>;
csrc/gdn_prefill_launcher.cu (2)

43-46: Redundant CUDA device property queries.

gdn_prefill already queries device properties (lines 157-160) and passes sm_count to the launcher. The device major version check here could use the same mechanism to avoid duplicate queries.

Consider performing the check in gdn_prefill before calling the launcher, passing the device major version down alongside sm_count, or caching the queried properties.


157-161: Missing CUDA error checks.

The cudaGetDevice and cudaGetDeviceProperties calls do not check return values. While failures are rare, unchecked errors could lead to undefined behavior if the device ID is invalid.

🔎 Suggested improvement
   int dev_id;
-  cudaGetDevice(&dev_id);
+  CUDA_CHECK(cudaGetDevice(&dev_id));
   cudaDeviceProp device_properties;
-  cudaGetDeviceProperties(&device_properties, dev_id);
+  CUDA_CHECK(cudaGetDeviceProperties(&device_properties, dev_id));
   int32_t sm_count = device_properties.multiProcessorCount;
csrc/flat/hopper/collective/flat_collective_load.hpp (1)

98-98: Clarify cluster support limitation.

The comment "do not support cluster" could be more informative. Consider adding context about why cluster support is not implemented or when it might be added.

benchmarks/bench_gdn_prefill.py (1)

25-56: Unused parameters in gdn_flops function.

The parameters num_k_heads and num_seqs are unused. If they're kept for API consistency with gdn_bytes, add a leading underscore or comment explaining why they're present but unused.

🔎 Option 1: Prefix with underscore
 def gdn_flops(
     total_seq_len: int,
     num_q_heads: int,
-    num_k_heads: int,
+    _num_k_heads: int,
     num_v_heads: int,
     head_size: int,
-    num_seqs: int,
+    _num_seqs: int,
 ) -> int:
tests/gdn/test_prefill_delta_rule.py (1)

39-39: Unused block_size parameter.

The block_size parameter is passed to both _test_prefill_kernel and _test_chunked_prefill but never used. Either use it in the test logic or remove it from the function signatures and parametrization.

csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (1)

35-58: Register requirement calculation with magic numbers.

The register calculation uses hardcoded values (40, 128, 64*1024, 248, 8). Consider adding comments explaining the rationale for these constants, particularly:

  • Why load_registers differs between debug and non-debug builds
  • The significance of 248 as the max register cap
tests/gdn/reference_delta_rule.py (1)

28-45: Clarify the asymmetric dtype return behavior.

The matmul function returns different dtypes depending on input:

  • When either input is bfloat16, it returns float32 (line 41)
  • When inputs are float16, it returns float16 (line 43)

This asymmetric behavior could be confusing. If this is intentional for numerical stability reasons, consider adding a comment explaining the rationale. Otherwise, consider making the return dtype consistent.
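A tiny dtype-promotion model makes the asymmetry concrete. This is an assumed model of the behavior described above (function name and semantics are illustrative, not the reference code itself):

```python
def promoted_dtype(a_dtype: str, b_dtype: str) -> str:
    """Model of the asymmetric promotion rule: any bfloat16 operand
    promotes the matmul result to float32, while pure float16 inputs
    return float16 (assumed semantics, for illustration only)."""
    if "bfloat16" in (a_dtype, b_dtype):
        return "float32"
    return "float16"
```

If the asymmetry is intentional (bfloat16's 8-bit mantissa makes a float32 accumulator output the safer choice), a comment stating exactly this rule would resolve the confusion.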

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a75b943 and a0fb220.

📒 Files selected for processing (29)
  • benchmarks/bench_gdn_prefill.py
  • csrc/flat/ampere/collective/flat_collective_inverse.hpp
  • csrc/flat/ampere/collective/flat_collective_load.hpp
  • csrc/flat/common.hpp
  • csrc/flat/cute_ext.hpp
  • csrc/flat/debug.hpp
  • csrc/flat/hopper/collective/flat_collective_load.hpp
  • csrc/flat/hopper/collective/flat_collective_store.hpp
  • csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
  • csrc/flat/hopper/collective/flat_common.hpp
  • csrc/flat/hopper/collective/flat_named_barriers.hpp
  • csrc/flat/hopper/device/device_universal.hpp
  • csrc/flat/hopper/kernel/flat_kernel_builder_delta_rule.hpp
  • csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
  • csrc/flat/hopper/kernel/flat_options.hpp
  • csrc/flat/hopper/kernel/flat_tile_scheduler.hpp
  • csrc/flat/math.hpp
  • csrc/flat/math_order_barrier.hpp
  • csrc/flat/prefill/prefill_kernel.hpp
  • csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu
  • csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh
  • csrc/flat/type_traits.hpp
  • csrc/flat/unused.hpp
  • csrc/gdn_prefill_launcher.cu
  • flashinfer/gdn_prefill.py
  • flashinfer/jit/gdn.py
  • tests/gdn/conftest.py
  • tests/gdn/reference_delta_rule.py
  • tests/gdn/test_prefill_delta_rule.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • csrc/flat/unused.hpp
  • tests/gdn/conftest.py
  • csrc/flat/cute_ext.hpp
  • csrc/flat/hopper/collective/flat_named_barriers.hpp
  • csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp
🧰 Additional context used
📓 Path-based instructions (4)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/gdn_prefill.py
  • flashinfer/jit/gdn.py
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/gdn_prefill_launcher.cu
  • csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/gdn/reference_delta_rule.py
  • tests/gdn/test_prefill_delta_rule.py
flashinfer/jit/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/jit/**/*.py: JIT module generators in flashinfer/jit/ must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Use gen_jit_spec() function to return a properly configured JitSpec from module generators with appropriate sources and extra_cuda_cflags
Specify supported_major_versions in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)

Files:

  • flashinfer/jit/gdn.py
🧠 Learnings (12)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support

Applied to files:

  • flashinfer/gdn_prefill.py
  • flashinfer/jit/gdn.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • flashinfer/gdn_prefill.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed

Applied to files:

  • csrc/gdn_prefill_launcher.cu
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Use `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads

Applied to files:

  • csrc/gdn_prefill_launcher.cu
  • csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • csrc/flat/prefill/prefill_kernel.hpp
  • csrc/flat/ampere/collective/flat_collective_inverse.hpp
  • csrc/flat/common.hpp
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/flat/ampere/collective/flat_collective_inverse.hpp
  • csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp
  • csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • csrc/flat/common.hpp
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

Applied to files:

  • tests/gdn/test_prefill_delta_rule.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`

Applied to files:

  • flashinfer/jit/gdn.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec

Applied to files:

  • flashinfer/jit/gdn.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)

Applied to files:

  • flashinfer/jit/gdn.py
🧬 Code graph analysis (8)
flashinfer/gdn_prefill.py (3)
flashinfer/api_logging.py (1)
  • flashinfer_api (464-565)
flashinfer/jit/gdn.py (1)
  • gen_gdn_prefill_sm90_module (25-37)
csrc/gdn_prefill_launcher.cu (2)
  • gdn_prefill (71-170)
  • gdn_prefill (71-73)
csrc/flat/math_order_barrier.hpp (1)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (8)
  • void (294-336)
  • void (510-515)
  • void (518-541)
  • void (544-562)
  • void (565-583)
  • void (587-604)
  • void (607-1000)
  • void (1003-1230)
csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (3)
csrc/flat/hopper/collective/flat_collective_store.hpp (2)
  • bool (180-192)
  • params (150-150)
csrc/flat/hopper/device/device_universal.hpp (4)
  • params (72-72)
  • params (72-72)
  • params (160-191)
  • params (160-160)
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (3)
  • params (203-205)
  • params (203-203)
  • params (220-220)
benchmarks/bench_gdn_prefill.py (2)
flashinfer/gdn_prefill.py (2)
  • gdn_prefill (37-60)
  • chunk_gated_delta_rule (81-202)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (1484-1631)
csrc/flat/hopper/device/device_universal.hpp (3)
csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp (9)
  • args (499-501)
  • args (499-499)
  • can_implement (426-443)
  • can_implement (426-426)
  • params (510-510)
  • initialize_workspace (504-508)
  • initialize_workspace (504-506)
  • to_underlying_arguments (446-497)
  • to_underlying_arguments (446-447)
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (12)
  • args (189-191)
  • args (189-189)
  • args (193-197)
  • args (193-194)
  • args (199-201)
  • args (199-199)
  • args (212-218)
  • args (212-212)
  • params (203-205)
  • params (203-203)
  • params (220-220)
  • block (207-210)
csrc/flat/hopper/collective/flat_collective_store.hpp (5)
  • params (150-150)
  • initialize_workspace (144-148)
  • initialize_workspace (144-146)
  • to_underlying_arguments (119-134)
  • to_underlying_arguments (119-120)
csrc/flat/hopper/collective/flat_collective_store.hpp (3)
csrc/flat/cute_ext.hpp (2)
  • alignment_for_swizzle (48-50)
  • alignment_for_swizzle (48-48)
csrc/flat/ampere/collective/flat_collective_inverse.hpp (6)
  • void (85-115)
  • void (119-171)
  • void (180-189)
  • void (192-275)
  • void (278-360)
  • void (363-476)
include/flashinfer/attention/blackwell/kernel/fmha_tile_scheduler.hpp (1)
  • make_coord (94-97)
csrc/flat/ampere/collective/flat_collective_load.hpp (1)
csrc/flat/hopper/collective/flat_collective_load.hpp (2)
  • char (32-42)
  • to_string (32-32)
flashinfer/jit/gdn.py (1)
flashinfer/jit/core.py (2)
  • JitSpec (216-397)
  • gen_jit_spec (400-466)
🪛 Ruff (0.14.10)
flashinfer/gdn_prefill.py

91-91: Unused function argument: use_qk_l2norm_in_kernel

(ARG001)

benchmarks/bench_gdn_prefill.py

28-28: Unused function argument: num_k_heads

(ARG001)


31-31: Unused function argument: num_seqs

(ARG001)

tests/gdn/test_prefill_delta_rule.py

39-39: Unused function argument: block_size

(ARG001)


227-227: Unused function argument: block_size

(ARG001)

flashinfer/jit/gdn.py

35-35: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

🔇 Additional comments (42)
csrc/flat/ampere/collective/flat_collective_inverse.hpp (5)

27-72: LGTM! Well-structured helper functions.

The is_contiguous template correctly handles both scalar and ranked layouts, and the SM80 namespace helpers properly convert fragment layouts and accumulator tensors for GEMM operations.


84-115: LGTM! Well-orchestrated multi-stage inverse.

The staged approach (8×8 → 16×16 → 32×32 → 64×64) with proper barrier synchronization ensures correct dependencies between block inverse computations.


117-171: LGTM! Efficient warp-shuffle based inversion.

The implementation uses warp shuffle operations for efficient thread communication during Gauss-Jordan elimination. The conditional handling of garbage-filled regions is well-structured, and the FP32 intermediate computation helps mitigate precision concerns.


192-360: LGTM! Correct implementation of block matrix inversion.

Both blockwise_diagonal_inversed_8x8_to_16x16 and blockwise_diagonal_inversed_16x16_to_32x32 correctly implement the block inverse formula documented at lines 173-177. The tensor partitioning, tiled MMA operations, and type conversions are well-structured.


362-476: LGTM! Complex cross-warp reduction correctly implemented.

Despite the duplicate typedef issue (flagged separately), the warp-level parallelization and cross-warp reduction logic is correct. The careful buffer reuse with barriers (line 458) and the split x==0/x==1 write-reduce pattern ensure correct results.

csrc/flat/hopper/collective/flat_common.hpp (3)

1-166: AI summary describes a different file.

The AI-generated summary claims this file:

  • "Introduces three public macros: FLAT_UNUSED_PARAMETER(x), CHECK(expr, msg), CUDA_CHECK(expr)"
  • "Includes debug.hpp and provides centralized error handling"

However, the actual file contains GEMM/GMMA utility functions in the flat::collective namespace with no macros, no debug.hpp include, and no error-handling utilities. The summary appears to describe an entirely different file.


52-86: LGTM!

The three convert_to_gmma_rs overloads correctly handle different MMA atom and TiledMMA types. The second overload appropriately defaults to Major::K for both operands when major layout parameters are not provided.


88-93: Verify divisibility assumption in layout conversion.

Line 90 performs integer division size<0>(c) / size(a) without checking if size<0>(c) is divisible by size(a). If this assumption is violated, the layout conversion will produce incorrect results silently.

Consider:

  1. Adding a static_assert or runtime check to validate the divisibility
  2. Documenting the precondition that size<0>(c) must be a multiple of size(a)
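The failure mode is silent truncation: integer division simply drops the remainder. A minimal sketch of the suggested precondition (names are illustrative; the real check would be a static_assert in the C++ layout code):

```python
def split_mode(c0: int, a: int) -> int:
    """Sketch of the precondition suggested above: the outer mode of c
    must split evenly by size(a). Without the check, integer division
    silently drops the remainder (e.g. 10 // 4 == 2, losing 2 elements)."""
    if c0 % a != 0:
        raise ValueError(f"size<0>(c)={c0} is not divisible by size(a)={a}")
    return c0 // a
```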
csrc/flat/math_order_barrier.hpp (1)

27-97: Implementation looks solid.

The ordered barrier coordination logic is well-documented with clear state-transition comments. The use of variadic templates for flexible WG-to-NB mapping is elegant.

csrc/flat/debug.hpp (1)

20-78: Debug infrastructure is well-designed.

The conditional compilation approach ensures zero overhead when debugging is disabled. The variety of print macros (thread-filtered, warp-filtered, warp-group-filtered) provides flexible granularity for debugging CUDA kernels.

csrc/flat/common.hpp (1)

35-44: CUDA_CHECK macro looks correct.

The macro properly captures the CUDA error, checks the result, formats an error message with error name and code, and throws a runtime exception. The implementation follows best practices for CUDA error handling.

csrc/flat/type_traits.hpp (2)

24-32: Type mapping utility looks correct.

The map_to_cutlass trait provides a clean abstraction for mapping native CUDA types (half, nv_bfloat16) to their Cutlass equivalents. The specializations cover the expected type conversions.


34-49: first_non_void implementation is correct.

The variadic template metafunction correctly:

  • Guards against all-void parameter packs with a static_assert
  • Returns the first non-void type via partial specialization
  • Recursively skips void types using inheritance

This is a clean, idiomatic implementation of type selection at compile time.

csrc/flat/hopper/kernel/flat_tile_scheduler.hpp (1)

29-71: LGTM!

The WorkDesc struct is well-designed with proper compile-time branching for GQA vs GVA grouping semantics. The accessor methods provide a clean abstraction over the internal head index representation.

csrc/flat/hopper/collective/flat_collective_store.hpp (3)

45-53: LGTM!

The smid() helper correctly retrieves the streaming multiprocessor ID using inline PTX assembly, with proper __CUDA_ARCH__ guarding for host/device compilation.


194-231: LGTM!

The step() method correctly implements pipelined TMA stores with proper barrier synchronization. The tail-tile handling via modified tensormap is a good approach for handling variable-length sequences without padding.


233-257: LGTM!

The tensormap creation for tail tiles is well-implemented. The parallel copy across warp lanes, followed by __syncwarp() barriers and PTX fence operations, correctly ensures the modified descriptor is visible before use.

csrc/flat/hopper/kernel/flat_options.hpp (2)

24-76: LGTM!

The compile-time option system is well-designed using standard template metaprogramming patterns. The recursive find_option_impl correctly handles tag lookup with default fallback, and add_option_t provides clean option composition.


77-92: LGTM!

The Tag enum provides a comprehensive set of configuration options for the kernel builder. The inline comments explaining the purpose of each tag are helpful for understanding the kernel configuration semantics.

csrc/flat/math.hpp (1)

24-34: LGTM!

The ceil_log2 and next_power_of_two implementations are correct for their intended use case. The recursive ceil_log2 correctly computes the ceiling of log₂(n), yielding the smallest exponent e such that 2^e >= n.
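The recursion can be checked against the stated property (smallest e with 2^e >= n) with a direct Python transliteration, assuming the C++ helpers have the semantics described above:

```python
def ceil_log2(n: int) -> int:
    """Recursive ceiling of log2(n): smallest e such that 2**e >= n
    (assumed to mirror the C++ helper discussed above)."""
    return 0 if n <= 1 else 1 + ceil_log2((n + 1) // 2)

def next_power_of_two(n: int) -> int:
    return 1 << ceil_log2(n)
```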

flashinfer/gdn_prefill.py (2)

30-77: LGTM on the module setup!

The module correctly uses @functools.cache for caching (per coding guidelines) and properly declares mutates_args=("output", "output_state") for the custom op registration.


178-184: LGTM!

The comment correctly explains that output state must always be allocated because the kernel unconditionally writes to it. This is the right approach to ensure kernel correctness.

csrc/flat/prefill/prefill_kernel.hpp (1)

22-25: LGTM!

The forward declaration of cutlass::arch::Sm90 is a good practice to avoid including full CUTLASS headers in this interface header.

csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cuh (2)

57-77: LGTM!

The compile-time options construction using an IIFE with decltype is an elegant pattern for building the constexpr options tuple based on template parameters.


131-132: LGTM!

The workspace allocation using cutlass::device_memory::allocation follows RAII semantics, ensuring proper cleanup when the function returns.

csrc/gdn_prefill_launcher.cu (1)

71-170: Well-structured validation and kernel invocation.

The gdn_prefill function has comprehensive input validation including:

  • GQA/GVA head count consistency checks
  • Shape and dtype validation for all tensors
  • Proper handling of optional tensors (input_state, alpha, beta)
  • Default scale computation

The TVM-FFI integration follows the coding guidelines for framework bindings in csrc/.

csrc/flat/hopper/collective/flat_collective_load.hpp (1)

44-100: Well-designed TMA collective load abstraction.

The CollectiveLoadTma template cleanly separates Q vs K/V loading patterns with appropriate tile shapes and layouts. The compile-time branching via if constexpr ensures no runtime overhead.

benchmarks/bench_gdn_prefill.py (1)

102-178: Well-structured benchmark function.

The benchmark follows best practices:

  • Proper input tensor setup with L2-normalized keys for numerical stability
  • Warmup call before timing
  • Uses bench_gpu_time with CUPTI for accurate GPU kernel timing
  • Calculates meaningful metrics (TFLOPs, TB/s)
csrc/flat/prefill/prefill_kernel_delta_rule_sm90.cu (2)

38-86: Exhaustive dispatch pattern for kernel configuration.

The dispatch logic correctly covers all 16 combinations of the 4 boolean flags (is_gva, needs_beta, needs_alpha, init_state). The #undef LAUNCH at line 86 properly cleans up the macro.

The "unreachable" exceptions are correct dead code since the if-else chains are exhaustive.


89-101: Explicit template instantiations provided.

Both FP16 (half) and BF16 (nv_bfloat16) instantiations are provided, matching the dtype dispatch in gdn_prefill_launcher.

tests/gdn/test_prefill_delta_rule.py (2)

32-131: Well-structured test helper with proper validation.

The _test_prefill_kernel function:

  • Correctly skips unstable cases (no alpha and no beta)
  • Properly seeds RNG for reproducibility
  • L2-normalizes keys for numerical stability
  • Uses appropriate tolerances per dtype
  • Correctly transposes state to match reference implementation layout

220-377: Comprehensive chunked prefill test.

The _test_chunked_prefill function correctly:

  • Tests state continuity across chunks by passing our_state1 to the second call
  • Concatenates variable-length sequences for reference comparison
  • Uses the same tolerance strategy as the basic test
csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp (2)

405-499: Warp-specialized kernel execution pattern.

The kernel correctly implements warp-group specialization with:

  • LdSt warp group handling Q/K/V loads, O stores, and alpha/beta loads via separate warps
  • Math0/Math1 warp groups performing state computation with barrier coordination
  • MathA auxiliary warp group for preprocessing

The register allocation pattern (reg_dealloc for LdSt/MathA, reg_alloc for Math0/Math1) balances registers between load/store and compute paths.


60-218: Well-organized kernel struct with clear abstractions.

The FlatKernelTmaWarpSpecializedDeltaRule struct cleanly encapsulates:

  • Pipeline types and state management
  • Shared storage layout with proper alignment
  • Problem shape and argument handling
  • Static configuration (thread counts, register requirements)
csrc/flat/ampere/collective/flat_collective_load.hpp (3)

27-40: LGTM!

The LoadKindVector enum and to_string helper follow the same pattern as the Hopper implementation and are correctly implemented.


42-108: LGTM!

The CollectiveLoadVector template structure and partition_SD method are well-implemented. The tensor partitioning logic correctly handles the setup for collective loads with appropriate tail masking.


110-157: LGTM!

The step method correctly implements the pipelined collective load with proper tail handling, memory fencing, and optional vector processing. The implementation follows CUDA best practices for shared memory synchronization.

tests/gdn/reference_delta_rule.py (5)

21-25: LGTM!

The exclusive_cumsum implementation is straightforward and correct for computing sequence offsets.
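For readers unfamiliar with the convention: an exclusive cumsum over sequence lengths yields each sequence's start offset in the packed layout. A minimal sketch with the assumed semantics:

```python
def exclusive_cumsum(lens):
    """Exclusive cumulative sum: offsets[i] is the sum of lens[:i],
    i.e. the start offset of sequence i in a packed buffer
    (assumed semantics of the helper discussed above)."""
    offsets, total = [], 0
    for n in lens:
        offsets.append(total)
        total += n
    return offsets
```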


48-74: LGTM!

Both LambdaQ and LambdaK correctly implement decay factor computation with appropriate handling for partial blocks. The comment on line 61 helpfully highlights the crucial difference for tail blocks.


78-123: LGTM!

The _linear_attention function correctly implements causal linear attention with optional decay factors and custom QK weights.


248-324: LGTM!

The delta_rule function correctly implements the delta rule algorithm with proper handling for GQA and GVA configurations. The state update logic (lines 312-317) follows the mathematical formulation described in the comment.


348-493: LGTM!

The blockwise_delta_rule function implements a complex blockwise delta rule algorithm with proper handling for:

  • GQA/GVA configurations (lines 376-383)
  • Tail block padding with appropriate defaults (lines 425-430: alpha=1.0, beta=0.0)
  • Block-level state updates with decay factors
  • Inter-block and intra-block attention contributions

The implementation is mathematically sound and handles edge cases appropriately.
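The padding defaults are chosen so padded positions act as identity updates: alpha=1.0 applies no extra decay and beta=0.0 contributes nothing to the state. A sketch of this convention (function name and list-based signature are illustrative, not the reference code):

```python
def pad_tail_block(alpha, beta, block_size):
    """Pad a partial tail block so padded positions leave the recurrent
    state unchanged: alpha=1.0 (neutral decay), beta=0.0 (no update
    contribution). Assumed rationale for the defaults noted above."""
    pad = block_size - len(alpha)
    return alpha + [1.0] * pad, beta + [0.0] * pad
```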

Comment on lines +1 to +20
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/pipeline/sm90_pipeline.hpp"

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if there's a corresponding hopper directory and whether this file should be there
echo "=== Checking for hopper collective directories ==="
fd -t d "hopper.*collective" csrc/

echo -e "\n=== Checking for other files in ampere that use sm90_pipeline ==="
rg -n "sm90_pipeline" csrc/flat/ampere/ 2>/dev/null || echo "No matches found"

echo -e "\n=== List all ampere directories ==="
fd -t d "ampere" csrc/

echo -e "\n=== Check hopper directory structure ==="
fd -t d "hopper" csrc/ | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 428


🏁 Script executed:

#!/bin/bash
# Check what's inside hopper directory
echo "=== Contents of csrc/flat/hopper/ ==="
find csrc/flat/hopper/ -type f -o -type d | sort

echo -e "\n=== Search for sm90_pipeline usage across entire codebase ==="
rg -l "sm90_pipeline" csrc/ --type cpp --type hpp | head -20

echo -e "\n=== Check if there's a hopper/collective subdirectory ==="
ls -la csrc/flat/hopper/ 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 1295


🏁 Script executed:

#!/bin/bash
# Compare the two flat_collective_load.hpp files
echo "=== Ampere version (first 50 lines) ==="
head -50 csrc/flat/ampere/collective/flat_collective_load.hpp

echo -e "\n=== Hopper version (first 50 lines) ==="
head -50 csrc/flat/hopper/collective/flat_collective_load.hpp

echo -e "\n=== Check if files are identical ==="
if cmp -s csrc/flat/ampere/collective/flat_collective_load.hpp csrc/flat/hopper/collective/flat_collective_load.hpp; then
  echo "Files are IDENTICAL"
else
  echo "Files are DIFFERENT"
  echo "Ampere file size: $(wc -c < csrc/flat/ampere/collective/flat_collective_load.hpp)"
  echo "Hopper file size: $(wc -c < csrc/flat/hopper/collective/flat_collective_load.hpp)"
fi

Repository: flashinfer-ai/flashinfer

Length of output: 3301


🏁 Script executed:

#!/bin/bash
# Check all files in ampere/collective and their sm90_pipeline usage
echo "=== All files in ampere/collective ==="
ls -1 csrc/flat/ampere/collective/

echo -e "\n=== Check sm90_pipeline references in ampere/collective ==="
grep -n "sm90_pipeline" csrc/flat/ampere/collective/*.hpp

echo -e "\n=== Check sm90_pipeline references in hopper/collective ==="
grep -n "sm90_pipeline" csrc/flat/hopper/collective/*.hpp

echo -e "\n=== Check if other hopper files use sm90_pipeline ==="
grep -r "sm90_pipeline" csrc/flat/hopper/ --include="*.hpp"

Repository: flashinfer-ai/flashinfer

Length of output: 639


🏁 Script executed:

#!/bin/bash
# Read the full ampere collective_load file to check for architectural context
echo "=== Full ampere/collective/flat_collective_load.hpp ==="
cat csrc/flat/ampere/collective/flat_collective_load.hpp

echo -e "\n=== Check for any build guards or architecture checks ==="
grep -i "sm80\|ampere\|hopper\|sm90\|arch" csrc/flat/ampere/collective/flat_collective_load.hpp

Repository: flashinfer-ai/flashinfer

Length of output: 6182


Remove this file from csrc/flat/ampere/collective/ or add architecture guards.

This file in the ampere (SM80) directory imports cutlass/pipeline/sm90_pipeline.hpp on line 20 without any architecture guards. The SM90 pipeline is specific to Hopper (SM90+) architecture and should not be used in Ampere code without proper conditional compilation.

Either:

  • Move this to the hopper directory (the correct implementation already exists at csrc/flat/hopper/collective/flat_collective_load.hpp), or
  • Replace the SM90 pipeline with an architecture-appropriate pipeline for Ampere and add proper #ifdef guards if the ampere implementation is intentional.
🤖 Prompt for AI Agents
In csrc/flat/ampere/collective/flat_collective_load.hpp lines 1-20, this Ampere
directory file includes the Hopper-specific header
cutlass/pipeline/sm90_pipeline.hpp; remove or correct it. Either delete this
file from csrc/flat/ampere/collective/ if the implementation belongs in hopper,
or replace the sm90 include with the appropriate Ampere pipeline header and add
conditional compilation guards (e.g. #if defined(CUTLASS_SM80) ||
defined(CUTLASS_SM_80) / #elif defined(CUTLASS_SM90) ... / #endif or use an
equivalent project macro) so the SM90 pipeline is only included when building
for SM90/Hopper; alternatively move the file to csrc/flat/hopper/collective/ to
match the existing hopper implementation.

Comment on lines +37 to +48
#define DPRINTF_W(fmt, ...) \
if (IS_PRINT_BLOCK) \
printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \
threadIdx.x, ##__VA_ARGS__)
#define DPRINTF0_W(fmt, ...) \
if (IS_PRINT_BLOCK && threadIdx.x % 32 == 0) \
printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \
threadIdx.x, ##__VA_ARGS__)
#define DPRINTF_WG(fmt, ...) \
if (IS_PRINT_BLOCK) \
printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \
threadIdx.x, ##__VA_ARGS__)

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat -n csrc/flat/debug.hpp | sed -n '30,55p'

Repository: flashinfer-ai/flashinfer

Length of output: 1951


DPRINTF_W and DPRINTF_WG are identical and should differ.

Both macros have identical implementations despite their distinct naming. The "0" variants clearly establish a pattern where DPRINTF0_W filters at warp level (threadIdx.x % 32 == 0) and DPRINTF0_WG filters at warp-group level (threadIdx.x % 128 == 0). The non-"0" variants should follow the same convention: consider whether DPRINTF_WG should filter on warp-group boundaries (e.g., threadIdx.x % 128 == 0) or whether the current unconditional behavior is intentional.

🤖 Prompt for AI Agents
In csrc/flat/debug.hpp around lines 37 to 48, DPRINTF_W and DPRINTF_WG are
identical but should differ: keep DPRINTF_W as the current per-thread print
(only gated by IS_PRINT_BLOCK) and change DPRINTF_WG to only print on warp-group
boundaries by adding a condition like threadIdx.x % 128 == 0 in its if check
(i.e., if (IS_PRINT_BLOCK && threadIdx.x % 128 == 0)); keep the same printf
format and arguments to maintain consistent log output.
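The warp versus warp-group gating the review describes can be sanity-checked outside CUDA. This is a small Python simulation of the predicates only; the block size of 256 and the helper name are illustrative, not from the kernel:

```python
# Simulate which threads of a 256-thread block would print under each
# debug-macro predicate (mirrors the conditions discussed for debug.hpp).
BLOCK_SIZE = 256  # illustrative block size (2 warp groups of 128 threads)

def printing_threads(predicate):
    """Return the threadIdx.x values for which the predicate holds."""
    return [tid for tid in range(BLOCK_SIZE) if predicate(tid)]

per_thread = printing_threads(lambda tid: True)               # DPRINTF_W today
per_warp = printing_threads(lambda tid: tid % 32 == 0)        # DPRINTF0_W
per_warpgroup = printing_threads(lambda tid: tid % 128 == 0)  # suggested DPRINTF_WG gate

print(len(per_thread), per_warp, per_warpgroup)
```

Under the suggested gate, DPRINTF_WG would fire once per warp group rather than once per thread, which is the distinction the "0" variants already encode.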

#pragma once

#include "cute/tensor.hpp"
#include "cutlass/kernel_hardware_info.h"

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if kernel_hardware_info is used in this file
echo "=== Checking for kernel_hardware_info usage in flat_common.hpp ==="
rg -n 'kernel_hardware_info|KernelHardwareInfo' csrc/flat/hopper/collective/flat_common.hpp

echo ""
echo "=== Checking the file structure and includes ==="
head -30 csrc/flat/hopper/collective/flat_common.hpp

echo ""
echo "=== Checking for files that include flat_common.hpp ==="
rg -l 'flat_common\.hpp' --type=cpp --type=cu csrc/ | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 1417


🏁 Script executed:

#!/bin/bash
# Check the entire file content
echo "=== Full content of flat_common.hpp ==="
cat csrc/flat/hopper/collective/flat_common.hpp

echo ""
echo "=== Files including flat_common.hpp ==="
rg -l 'flat_common\.hpp' csrc/ | head -20

echo ""
echo "=== Checking if kernel_hardware_info is used in files that include flat_common.hpp ==="
rg -l 'kernel_hardware_info' csrc/flat/hopper/collective/ | grep -v flat_common.hpp

Repository: flashinfer-ai/flashinfer

Length of output: 6884


Remove unused include.

The cutlass/kernel_hardware_info.h header is included but not used anywhere in this file. Remove it.

🤖 Prompt for AI Agents
In csrc/flat/hopper/collective/flat_common.hpp around line 19, remove the unused
include directive for "cutlass/kernel_hardware_info.h" — delete that include
line and ensure no other symbols from that header are referenced elsewhere in
this file; run a local build to confirm no missing dependencies after removal.

Comment on lines +109 to +161
if constexpr (sizeof(Element) == 1) {
// 00 11 22 33 00 11 22 33 acc layout
// 00 00 11 11 22 22 33 33 operand layout
// BB AA AA BB AA BB BB AA conflict-free exchange pattern
// 16-bit exchange; so process two at a time potentially
int tid = threadIdx.x % 4;
auto values_u32 = recast<uint32_t>(operand);

CUTE_UNROLL
for (int n = 0; n < size<1>(values_u32); n++) {
CUTE_UNROLL
for (int k = 0; k < size<2>(values_u32); k++) {
CUTE_UNROLL
for (int ii = 0; ii < 8; ii += 4) {
uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k);
uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k);

// step A:
// t 1 v 0 -> t 0 v 1
// t 2 v 0 -> t 1 v 0
// t 0 v 1 -> t 2 v 0
// t 3 v 1 -> t 3 v 1

int v_to_send = tid == 1 || tid == 2 ? 0 : 1;
int v_to_recv = v_to_send;
int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF;

uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1;

values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4);

// step B:
// t 0 v 0 -> t 0 v 0
// t 3 v 0 -> t 1 v 1
// t 1 v 1 -> t 2 v 1
// t 2 v 1 -> t 3 v 0

v_to_send = 1 - v_to_send;
v_to_recv = 1 - v_to_recv;
t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF;

uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1;

values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4);

values_u32(ii / 2 + 0, n, k) =
__byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410);
values_u32(ii / 2 + 1, n, k) =
__byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632);
}
}
}
}

🛠️ Refactor suggestion | 🟠 Major

Document the purpose and assumptions of the 1-byte element shuffle logic.

This code implements a complex warp-shuffle-based data exchange for 1-byte elements, but lacks critical documentation:

  1. Why is this exchange needed? The comments describe the exchange pattern but not the underlying problem (e.g., avoiding bank conflicts, matching expected layout, etc.)

  2. Thread grouping assumption: Line 114 uses threadIdx.x % 4, assuming threads are processed in groups of 4. This assumption should be documented.

  3. Magic constants: The hex values 0x3021, 0x2130, 0x1054, 0x5410, 0x3276, 0x7632 encode specific patterns. Consider:

    • Adding named constants with explanatory comments
    • Documenting the bit layout of these values
    • Explaining the relationship between the values and the exchange pattern
  4. Layout transformation: Document the input and output layouts explicitly (e.g., "transforms from ACC layout 00 11 22 33 00 11 22 33 to operand layout 00 00 11 11 22 22 33 33")

🤖 Prompt for AI Agents
In csrc/flat/hopper/collective/flat_common.hpp around lines 109-161, the 1-byte
element warp-shuffle exchange logic lacks documentation for intent, assumptions
and magic constants; update the code by adding a brief header comment that
states the purpose (why this exchange is needed and the layout transformation
performed, e.g., transforms ACC layout "00 11 22 33 00 11 22 33" to operand
layout "00 00 11 11 22 22 33 33"), explicitly document the thread grouping
assumption (that threadIdx.x is grouped modulo 4 and why), and replace inline
hex magic numbers with named constants and short comments explaining their
bit/byte encoding and how they map to the described exchange steps (A/B) and
shuffle masks so future readers can reason about the pattern quickly.
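As a concrete starting point for that documentation: each hex constant packs a four-entry source-lane table, one nibble per tid, decoded by the kernel's `(value >> (tid * 4)) & 0xF` expression. A quick Python decode (the helper name is ours) reproduces the lane tables stated in the step A/B comments:

```python
def decode_lane_table(packed: int) -> list[int]:
    """Decode a nibble-packed lane table: the entry for thread tid is
    (packed >> (tid * 4)) & 0xF, mirroring t_to_recv_from in the kernel."""
    return [(packed >> (tid * 4)) & 0xF for tid in range(4)]

# Step A: tid 0 receives from lane 1, tid 1 from 2, tid 2 from 0, tid 3 from 3.
step_a = decode_lane_table(0x3021)
# Step B: tid 0 receives from lane 0, tid 1 from 3, tid 2 from 1, tid 3 from 2.
step_b = decode_lane_table(0x2130)
print(step_a, step_b)  # [1, 2, 0, 3] [0, 3, 1, 2]
```

Named constants plus this one-line decoding rule in a comment would make the shuffle pattern auditable without re-deriving it.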

Comment on lines +249 to +250
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();

⚠️ Potential issue | 🟡 Minor

Unused variable block_rank_in_cluster.

block_rank_in_cluster is computed but never used in the function. Remove it or add a comment if it's reserved for future cluster support.

🔎 Suggested fix
     int lane_predicate = cute::elect_one_sync();
-    uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
int lane_predicate = cute::elect_one_sync();
🤖 Prompt for AI Agents
In csrc/flat/hopper/kernel/flat_kernel_tma_warpspecialized_delta_rule.hpp around
lines 249-250, the variable block_rank_in_cluster is computed but never used;
either remove the declaration entirely or explicitly mark it as intentionally
unused (e.g., add a comment "reserved for future cluster support" and a
(void)block_rank_in_cluster; or use an attribute/pragma to suppress unused
warnings) so the compiler doesn’t emit an unused-variable warning and intent is
clear.

Comment on lines +25 to +37
def gen_gdn_prefill_sm90_module() -> JitSpec:
return gen_jit_spec(
name="gdn_prefill_launcher",
sources=[
jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_launcher.cu",
jit_env.FLASHINFER_CSRC_DIR
/ "flat"
/ "prefill"
/ "prefill_kernel_delta_rule_sm90.cu",
],
extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
)

🛠️ Refactor suggestion | 🟠 Major

Specify supported GPU architectures for the SM90 module.

The JitSpec should include supported_major_versions to restrict compilation to SM90 and newer architectures. This prevents attempting to compile SM90-specific code on incompatible GPUs.

🔎 Proposed fix
 def gen_gdn_prefill_sm90_module() -> JitSpec:
     return gen_jit_spec(
         name="gdn_prefill_launcher",
         sources=[
             jit_env.FLASHINFER_CSRC_DIR / "gdn_prefill_launcher.cu",
             jit_env.FLASHINFER_CSRC_DIR
             / "flat"
             / "prefill"
             / "prefill_kernel_delta_rule_sm90.cu",
         ],
         extra_cuda_cflags=sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"],
         extra_include_paths=[jit_env.FLASHINFER_CSRC_DIR],
+        supported_major_versions=[9, 10, 11, 12],
     )

Based on coding guidelines for flashinfer/jit/**/*.py.

🧰 Tools
🪛 Ruff (0.14.10)

35-35: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

🤖 Prompt for AI Agents
In flashinfer/jit/gdn.py around lines 25 to 37, the JitSpec for the SM90 module
is missing a supported_major_versions constraint; update the gen_jit_spec call
to include supported_major_versions (e.g. supported_major_versions=[9, 10, 11, 12],
matching the suggested diff) so the launcher is only compiled on SM90-or-newer
GPUs, and place this parameter alongside the existing extra_cuda_cflags and
extra_include_paths arguments.
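On the Ruff RUF005 hint flagged above: concatenation and iterable unpacking build the same flags list, so that style fix is behavior-preserving. A minimal check with a placeholder flags list (the real sm90a_nvcc_flags lives in flashinfer's JIT module):

```python
# Placeholder standing in for flashinfer's sm90a_nvcc_flags.
sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"]

concatenated = sm90a_nvcc_flags + ["-DFLAT_SM90A_ENABLED", "-std=c++20"]
unpacked = [*sm90a_nvcc_flags, "-DFLAT_SM90A_ENABLED", "-std=c++20"]  # RUF005-preferred form

print(concatenated == unpacked)  # True
```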

k = k.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
v = v.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)

KVs = [] # FIXME: kernel debug only

🛠️ Refactor suggestion | 🟠 Major

Address the FIXME: Remove or make debug output optional.

Line 154 has a FIXME comment indicating KVs is for kernel debugging only. This debug list accumulates state for every block (line 238) and is returned as part of the function output (line 245), which could cause memory issues for long sequences.

Consider either:

  • Removing the debug code if no longer needed, or
  • Making it optional via a parameter (e.g., return_debug_info=False)
🔎 Proposed fix to make debug output optional
 @torch.inference_mode
 def blockwise_linear_attention(
     q: torch.Tensor,  # [total_seq_len, num_qo_heads, head_size]
     k: torch.Tensor,  # [total_seq_len, num_kv_heads, head_size]
     v: torch.Tensor,  # [total_seq_len, num_kv_heads, head_size]
     seq_lens: list[int],  # sequence length for each sequence
     block_size: int = 32,
     scale_factor=1.0,
     decay_factor: float
     | torch.Tensor = 1.0,  # float or tensor with num_elems == num_qo_heads
     decay_exponent_offset=0,
     kv_dtype: torch.dtype = torch.float32,
+    return_debug_info: bool = False,
 ) -> torch.Tensor:
     num_qo_heads = q.size(1)
     head_size = q.size(2)
     num_kv_heads = k.size(1)
 
     if scale_factor != 1.0:
         k = k * scale_factor
     if isinstance(decay_factor, float):
         decay_factor = torch.ones(num_qo_heads) * decay_factor
         decay_factor = decay_factor.to(q.device)
     assert decay_factor.numel() == num_qo_heads
     decay_factor = decay_factor.reshape(num_qo_heads, 1, 1)
 
     k = k.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
     v = v.repeat_interleave(num_qo_heads // num_kv_heads, dim=1)
 
-    KVs = []  # FIXME: kernel debug only
+    KVs = [] if return_debug_info else None
     kv = torch.zeros(
         (len(seq_lens), num_qo_heads, head_size, head_size),
         dtype=kv_dtype,
         device=q.device,
     )
     output = torch.zeros_like(q)
     
     # ... rest of the function ...
     
-            KVs.append(carried_kv.clone())
+            if return_debug_info:
+                KVs.append(carried_kv.clone())
 
             blk_offset += block_size
 
         # print(kv.shape, carried_kv.shape)
         kv[seq_idx, :, :] = carried_kv
 
-    return output, kv, KVs
+    if return_debug_info:
+        return output, kv, KVs
+    else:
+        return output, kv

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tests/gdn/reference_delta_rule.py around line 154 (and referencing behavior
at lines 238 and 245), the KVs list is a persistent debug accumulator marked
FIXME and can grow unbounded; make debug output optional by adding a new
parameter (e.g., return_debug_info=False) to the function signature, initialize
and append to KVs only when return_debug_info is True, and conditionally include
KVs in the function return (when True return the existing result plus debug
data, otherwise return the original result unchanged); update any call
sites/tests to use the new parameter as needed.

Comment on lines +327 to +336
def identity_add_strict_lower_diagonal(m: torch.Tensor):
SIZE = m.size(-1)
assert m.size(-2) == SIZE
with torch.device(m.device):
m = m.clone()
mask = torch.arange(SIZE).unsqueeze(1) <= torch.arange(SIZE)
m[:, mask] = 0.0
# m[mask.unsqueeze(0)] = 0.0
m = m + torch.eye(SIZE).unsqueeze(0)
return m

🛠️ Refactor suggestion | 🟠 Major

Fix the misleading function name.

The function identity_add_strict_lower_diagonal actually zeros the upper triangle (not the lower). Line 332 creates a mask where row <= col, which selects the upper triangular elements including the diagonal. Line 333 then zeros these elements.

The function effectively computes: strict_lower_triangle(m) + identity, so the name is backwards from the implementation.

Consider renaming to identity_add_strict_lower_triangle or keep_strict_lower_plus_identity for clarity.

🔎 Proposed fix
-def identity_add_strict_lower_diagonal(m: torch.Tensor):
+def identity_add_strict_lower_triangle(m: torch.Tensor):
     SIZE = m.size(-1)
     assert m.size(-2) == SIZE
     with torch.device(m.device):
         m = m.clone()
-        mask = torch.arange(SIZE).unsqueeze(1) <= torch.arange(SIZE)
+        # Keep only strict lower triangle (row > col), then add identity
+        mask = torch.arange(SIZE).unsqueeze(1) <= torch.arange(SIZE)  # upper triangle including diagonal
         m[:, mask] = 0.0
-        # m[mask.unsqueeze(0)] = 0.0
         m = m + torch.eye(SIZE).unsqueeze(0)
     return m

Then update the caller on line 446:

-            IKK = identity_add_strict_lower_diagonal(
+            IKK = identity_add_strict_lower_triangle(
                 beta_HS1 * torch.exp(Gamma_HSS) * matmul(k_HSK, k_HSK.transpose(-2, -1))
             )  # NOTE: beta scale row-wise

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tests/gdn/reference_delta_rule.py around lines 327 to 336 the function name
identity_add_strict_lower_diagonal is misleading because the implementation
zeros the upper triangle (mask row <= col) and then adds the identity, producing
strict lower triangle + identity; rename the function to
identity_add_strict_lower_triangle (or keep_strict_lower_plus_identity) and
update its single caller at line 446 to use the new name; ensure any
import/usage matches the new identifier and run tests to verify no other
references remain.
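The renaming argument can be verified numerically: zeroing every entry with row <= col removes the upper triangle including the diagonal, and adding the identity leaves strict_lower(m) + I. A dependency-free Python sketch of the same semantics (plain lists instead of torch, using the suggested new name):

```python
def keep_strict_lower_plus_identity(m: list[list[float]]) -> list[list[float]]:
    """Zero entries with row <= col (upper triangle incl. diagonal), add identity."""
    n = len(m)
    return [
        [(1.0 if r == c else 0.0) + (m[r][c] if r > c else 0.0) for c in range(n)]
        for r in range(n)
    ]

m = [[0.0, 1.0, 2.0],
     [3.0, 4.0, 5.0],
     [6.0, 7.0, 8.0]]
result = keep_strict_lower_plus_identity(m)
print(result)  # [[1.0, 0.0, 0.0], [3.0, 1.0, 0.0], [6.0, 7.0, 1.0]]
```

The output keeps only the strictly lower-triangular entries of m (3, 6, 7) plus ones on the diagonal, i.e. strict lower triangle + identity, confirming the name is backwards from the implementation.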

@yzh119 yzh119 merged commit aff0d3b into main Jan 3, 2026
4 checks passed
@yzh119 yzh119 deleted the feat/GDNAttention branch January 3, 2026 13:28
@vincentzed
Contributor

vincentzed commented Jan 5, 2026

@guangyunh-nv Is there Blackwell SM100 support in the plan? 🎉

@guangyunh-nv
Collaborator Author

There is, but not by me. ETA is unclear tho.

@ZJY0516
Contributor

ZJY0516 commented Jan 22, 2026

Do we have plans to support bfloat16 for parameters like input_state, g, and beta? Currently, they seem limited to float32, but in vllm they are bf16.

Is this a deliberate choice?

if (alpha.has_value()) {
TensorView alpha_tensor = alpha.value();
TVM_FFI_ICHECK_EQ(alpha_tensor.dtype(), dl_float32);

TVM_FFI_ICHECK_EQ(beta_tensor.dtype(), dl_float32);

TVM_FFI_ICHECK_EQ(input_state.value().dtype(), dl_float32);

@yzh119
Collaborator

yzh119 commented Jan 22, 2026

Hi @ZJY0516 I don't think that's a hard constraint, we can make it more flexible.

@guangyunh-nv
Collaborator Author

It is a minor change.

The problem is I am not sure it is wise to store the state in bf16. For prefilling, the internal states are all in FP32, which makes the long-range accumulation much more accurate. But for decoding, accumulating in BF16 at every recurrent step might introduce severe roundoff error.

I think this should be verified from the algorithm side first.
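The decoding concern can be illustrated with a toy accumulation in narrow precision. This Python sketch uses IEEE half (via struct's 'e' format) as a stand-in for bf16 — bf16 has even fewer mantissa bits (8 vs. 10), so its roundoff is worse — and is only an illustration of the failure mode, not the kernel's actual numerics:

```python
import struct

def to_half(x: float) -> float:
    """Round a float to IEEE half precision and back, emulating a
    half-precision recurrent-state accumulator."""
    return struct.unpack('e', struct.pack('e', x))[0]

STEP, N = 0.0001, 10_000  # true sum is 1.0

acc_half = 0.0
for _ in range(N):
    acc_half = to_half(acc_half + STEP)  # state rounded to half every step

acc_wide = 0.0
for _ in range(N):
    acc_wide += STEP  # state kept in wide precision

# Once acc_half grows large enough, STEP falls below half an ulp and the
# accumulator stops advancing entirely, while the wide sum stays near 1.0.
print(acc_half, acc_wide)
```

The half-precision accumulator stalls far below the true sum, which is exactly the per-step roundoff risk raised for BF16 decode state.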

yzh119 added a commit that referenced this pull request Feb 4, 2026
…ing. (#2422)

<!-- .github/pull_request_template.md -->

## 📌 Description

This PR implements these features:
1. accelerate hopper's gdn prefill compilation time by split compilation
2. fix the docstring of gdn prefill kernel, instead of [N, H, K, V], it
expects [N, H, V, K]

## 🔍 Related Issues

#2276 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @guangyunh-nv 


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

* **New Features**
* Enhanced JIT module generation for GDN prefill kernels with
template-driven compilation and separate kernel instantiation.

* **Improvements**
* JIT specification now intelligently handles C++ standard flags,
applying defaults only when not already specified.

* **Documentation**
* Clarified final state memory layout description for GDN prefill
operations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
raayandhar pushed a commit to raayandhar/flashinfer that referenced this pull request Feb 5, 2026
…ing. (flashinfer-ai#2422)


Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants