
feat(gdn): add BF16 state kernel with MTP support beyond T>4 with intermediate caching.#2679

Open
ameynaik-hub wants to merge 7 commits into flashinfer-ai:main from ameynaik-hub:ameyn/gdn_bf16_improvements

Conversation

@ameynaik-hub
Contributor

@ameynaik-hub ameynaik-hub commented Mar 3, 2026

Add a high-performance CuTe DSL kernel for GDN decode with BF16 hidden state storage. Provides both T=1 (single token) and MTP (multi-token prediction) variants using a cooperative row approach.

Key design:

  • Each warp processes one V-row at a time (4 warps = 4 V-rows/iter)
  • cp.async pipeline with TILE_V=8 x TILE_K=128 tiles
  • H state stored as BF16 in memory, FP32 in registers for compute
  • ILP-optimized variant for large batch sizes (BS>=32)
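
The precision split in the third bullet (BF16 in memory, FP32 in registers) can be illustrated with a small self-contained sketch; `to_bf16` here is a hypothetical helper that emulates bf16 storage by truncating an fp32 encoding (real hardware typically rounds to nearest-even rather than truncating):

```python
import struct

def to_bf16(x: float) -> float:
    # Emulate bf16 storage: keep only the top 16 bits of the fp32 encoding.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

# The H state lives in memory as bf16, but each delta-rule update
# accumulates in fp32 "registers" before rounding back to storage.
h_mem = to_bf16(1.2345)   # bf16 value as loaded from memory
h_reg = h_mem             # upcast to fp32 is exact
h_reg += 0.25 * 0.125     # fp32-precision update
h_mem = to_bf16(h_reg)    # round back to bf16 for the store
```

Doing the accumulation in fp32 avoids compounding bf16 rounding error across iterations; only the final store pays the precision cost.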

Consolidated from separate cooprow file into canonical gdn_decode_bf16_state.py, replacing the old 32x128 H-chunk kernel. Updated gdn_decode.py dispatch to use BF16 state kernel for both T=1 and MTP (T>1) when state is BF16 and K=V=128.
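
The dispatch rule described above can be sketched as a small selector. Kernel names follow the PR's naming, but the helper itself is illustrative, not the actual gdn_decode.py code:

```python
def select_gdn_decode_kernel(state_dtype: str, T: int, K: int, V: int) -> str:
    # BF16-state kernels are used only when the hidden state is stored
    # as bfloat16 and K = V = 128; T picks the single-token vs MTP variant.
    if state_dtype == "bfloat16" and K == 128 and V == 128:
        return "gdn_decode_bf16_state" if T == 1 else "gdn_decode_bf16_state_mtp"
    return "gated_delta_rule_mtp"  # FP32-state fallback path

print(select_gdn_decode_kernel("bfloat16", 4, 128, 128))
# → gdn_decode_bf16_state_mtp
```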

Benchmark results (B200, Qwen3-Next config, BF16 state MTP vs FP32 MTP):

  • BS=1-2: 1.09-1.35x speedup
  • BS=4-16: 1.24-2.21x speedup (biggest gains)
  • BS=32-512: 1.62-1.81x steady-state speedup
  • Peak: 13.8 TFLOPS (BS=512, T=8) vs FP32's 7.9 TFLOPS

Cooprow BF16 State vs Optimized FP32 MTP Benchmark

GPU: B200
Config: Qwen3-Next (q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON)
Mode: cache_intermediate_states=ON, disable_state_update=True

Kernels compared:

  • Cooprow BF16: gated_delta_rule_bf16state_cooprow_mtp — cooperative row BF16 state kernel
  • Optimized FP32 MTP: gated_delta_rule_mtp — FP32 state kernel with ILP rows (1/2/4/8) + SMEM V caching

1. Cooprow BF16 State Kernel Time (us)

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|
| 1 | 5.18 | 5.82 | 6.62 | 8.74 | 9.38 | 10.08 | 11.07 |
| 2 | 5.66 | 6.40 | 7.42 | 9.38 | 10.56 | 11.42 | 12.45 |
| 4 | 6.67 | 7.58 | 8.83 | 11.20 | 12.54 | 13.76 | 14.94 |
| 8 | 11.33 | 10.82 | 12.86 | 16.13 | 18.42 | 20.40 | 22.51 |
| 16 | 13.68 | 17.44 | 21.23 | 26.18 | 30.18 | 34.43 | 38.29 |
| 32 | 23.30 | 30.32 | 38.30 | 46.85 | 54.59 | 62.24 | 70.27 |
| 64 | 42.37 | 55.74 | 69.71 | 85.46 | 100.11 | 114.93 | 129.42 |
| 128 | 78.56 | 101.86 | 129.73 | 159.89 | 188.13 | 216.77 | 245.31 |
| 256 | 149.39 | 194.24 | 248.41 | 307.04 | 362.75 | 418.59 | 475.36 |
| 512 | 289.76 | 376.80 | 483.71 | 598.51 | 708.69 | 842.46 | 932.94 |

2. Optimized FP32 MTP Kernel Time (us)

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|
| 1 | 5.66 | 7.04 | 8.34 | 9.79 | 10.93 | 12.54 | 13.85 |
| 2 | 6.61 | 8.26 | 9.95 | 11.58 | 13.22 | 14.91 | 16.78 |
| 4 | 9.50 | 11.94 | 14.08 | 16.26 | 18.64 | 28.77 | 23.65 |
| 8 | 14.08 | 17.60 | 28.22 | 24.82 | 29.20 | 33.09 | 37.73 |
| 16 | 22.96 | 27.65 | 47.02 | 43.84 | 64.59 | 73.97 | 84.75 |
| 32 | 40.48 | 55.62 | 64.32 | 76.29 | 89.78 | 105.57 | 119.10 |
| 64 | 68.45 | 92.64 | 119.04 | 142.56 | 167.07 | 194.13 | 218.30 |
| 128 | 129.63 | 176.99 | 222.93 | 270.22 | 317.36 | 369.17 | 419.15 |
| 256 | 250.81 | 341.90 | 432.94 | 524.08 | 617.77 | 721.64 | 822.65 |
| 512 | 492.59 | 671.39 | 854.70 | 1039.02 | 1232.91 | 1431.24 | 1691.08 |

3. Speedup (FP32 time / BF16 time, >1.0 = BF16 wins)

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|
| 1 | 1.09 | 1.21 | 1.26 | 1.12 | 1.17 | 1.24 | 1.25 |
| 2 | 1.17 | 1.29 | 1.34 | 1.24 | 1.25 | 1.31 | 1.35 |
| 4 | 1.42 | 1.57 | 1.59 | 1.45 | 1.49 | 2.09 | 1.58 |
| 8 | 1.24 | 1.63 | 2.19 | 1.54 | 1.59 | 1.62 | 1.68 |
| 16 | 1.68 | 1.59 | 2.21 | 1.67 | 2.14 | 2.15 | 2.21 |
| 32 | 1.74 | 1.83 | 1.68 | 1.63 | 1.64 | 1.70 | 1.69 |
| 64 | 1.62 | 1.66 | 1.71 | 1.67 | 1.67 | 1.69 | 1.69 |
| 128 | 1.65 | 1.74 | 1.72 | 1.69 | 1.69 | 1.70 | 1.71 |
| 256 | 1.68 | 1.76 | 1.74 | 1.71 | 1.70 | 1.72 | 1.73 |
| 512 | 1.70 | 1.78 | 1.77 | 1.74 | 1.74 | 1.70 | 1.81 |

Summary

  • BS=1-2: 1.09-1.35x — cooprow BF16 wins but margins are smaller
  • BS=4-16: 1.24-2.21x — biggest gains; >2x spikes at BS=4-16 likely indicate tile-size transitions in the FP32 kernel
  • BS=32-512: 1.62-1.81x — consistent ~1.70-1.78x steady-state speedup
  • Peak TFLOPS: Cooprow BF16 reaches 13.8 TFLOPS (BS=512, T=8) vs FP32's 7.9 TFLOPS
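
As a sanity check, the steady-state speedups quoted above can be recomputed directly from the two kernel-time tables (times in microseconds, BS=512 row):

```python
# BS=512 row from tables 1 (BF16) and 2 (FP32) above.
bf16_us = {2: 289.76, 8: 932.94}
fp32_us = {2: 492.59, 8: 1691.08}

# Speedup = FP32 time / BF16 time, as in table 3.
speedup = {t: fp32_us[t] / bf16_us[t] for t in bf16_us}
print({t: round(s, 2) for t, s in speedup.items()})
# → {2: 1.7, 8: 1.81}
```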

AI-assisted (Claude Code)

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or via your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Automatic BF16 State selection: single-step vs multi-token (MTP) chosen at runtime by sequence length.
    • Exposes additional BF16 State kernel variants for improved multi-token performance.
  • Refactor

    • Unified "BF16 State" naming across CLI, benchmarks, outputs, and help text.
    • Default state-update behavior for gated-delta operations changed.
  • Tests

    • Expanded coverage for single-step and MTP BF16 State paths.
  • Documentation

    • Updated CLI help, examples, benchmark legends, and run descriptions.

@coderabbitai
Contributor

coderabbitai bot commented Mar 3, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR splits and renames the BF16-state GDN decode backend into two variants (single-token BF16 state and BF16-state MTP for multi-token), updates kernel exports/availability flags, changes runtime dispatch to choose T==1 vs T>1 kernels, and adjusts benchmarks/tests/CLI to match the new APIs.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Benchmarks / CLI<br>`benchmarks/bench_gdn_decode.py` | Renamed `klast_bf16` → `bf16_state`, relaxed seq-len to T>=1, runtime dispatch now calls single-token BF16 for T==1 or MTP for T>1, added MTP wrapper args and updated output/labels. |
| Kernel Exports<br>`flashinfer/gdn_kernels/__init__.py` | Removed `GatedDeltaRuleKernel` export; added `gated_delta_rule_mtp`, `gated_delta_rule_bf16state_cooprow`, `gated_delta_rule_bf16state_cooprow_mtp`; updated imports, ImportError fallbacks, and `__all__`. |
| Core Decode Implementation<br>`flashinfer/gdn_decode.py` | Renamed availability flag to `_GDN_DECODE_BF16_STATE_AVAILABLE`; introduced `_gated_delta_rule_bf16_state` and `_gated_delta_rule_bf16_state_mtp`; dispatch selects T==1 vs T>1/pool paths; changed `gated_delta_rule_mtp` default `disable_state_update` to False. |
| Tests<br>`tests/gdn/test_decode_delta_rule.py` | Replaced `klast_bf16` test references with `bf16_state`; added explicit T=1 and MTP (T>1) test branches and helpers; updated imports, assertions, messages, and smoke-test flows to new kernel naming. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Decode API
    participant Dispatch as Kernel Dispatch
    participant BF16 as BF16 State (T=1)
    participant MTP as BF16 State MTP (T>1)

    Client->>Dispatch: decode(T, dtype=bf16, args...)
    activate Dispatch
    alt T == 1
        Dispatch->>BF16: call _gated_delta_rule_bf16_state(...)
        activate BF16
        BF16-->>Dispatch: result
        deactivate BF16
    else T > 1
        Dispatch->>MTP: call _gated_delta_rule_bf16_state_mtp(...)
        activate MTP
        MTP-->>Dispatch: result
        deactivate MTP
    end
    Dispatch-->>Client: decoded output
    deactivate Dispatch
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related issues

Possibly related PRs

Suggested labels

model: qwen3-next, run-ci

Suggested reviewers

  • bkryu
  • cyx-6
  • nvmbreughe
  • jimmyzho
  • jiahanc
  • yzh119

Poem

🐇
I hopped from KLAST to BF16 at night,
One-step for T==1, MTP for flight.
Kernels renamed, dispatch set true—
A carrot-coded shuffle, through and through. 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 72.22%, which is insufficient; the required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main change: adding a BF16 state kernel with MTP support. It accurately reflects the primary objective of the pull request. |
| Description check | ✅ Passed | The description is mostly complete with technical details and benchmarks, but the required template sections (Description rationale, Related Issues) contain only placeholders. The checklist shows pre-commit and tests pending. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the GDN decode functionality by integrating a new, highly optimized BF16 state kernel. This kernel, built with CuTe DSL, provides substantial performance gains across various batch sizes and sequence lengths, particularly for multi-token prediction. The changes involve refactoring existing kernel implementations, updating benchmarking infrastructure, and expanding test coverage to ensure correctness and efficiency of the new BF16 state processing.

Highlights

  • New BF16 State Kernel: Introduced a high-performance CuTe DSL kernel for GDN decode with BF16 hidden state storage, supporting both single-token (T=1) and multi-token prediction (MTP) variants.
  • Performance Improvements: Achieved significant speedups: 1.09-1.35x for batch sizes 1-2, 1.24-2.21x for batch sizes 4-16 (biggest gains), and a consistent 1.62-1.81x steady-state speedup for batch sizes 32-512. Peak performance reached 13.8 TFLOPS compared to FP32's 7.9 TFLOPS.
  • Kernel Design: The new kernel employs a cooperative row approach, where each warp processes one V-row, utilizes a cp.async pipeline with TILE_V=8 x TILE_K=128 tiles, stores H state as BF16 in memory (FP32 in registers for compute), and includes an ILP-optimized variant for large batch sizes (BS>=32).
  • Refactored GDN Decode: The previous gdn_decode_klast_bf16_state kernel has been consolidated and replaced by the new gdn_decode_bf16_state.py which now handles both T=1 and MTP (T>1) scenarios when the state is BF16 and K=V=128.
  • Benchmarking and Testing: Updated benchmarks and tests to reflect the new BF16 state kernel, including dedicated tests for T=1 and MTP (T>=2) scenarios, and expanded seq_len support from T=1,2,3,4 to T>=1.


Changelog
  • benchmarks/bench_gdn_decode.py
    • Renamed references from gdn_decode_klast_bf16_state to bf16_state for clarity and consistency.
    • Updated kernel import and availability check to GDN_DECODE_BF16_STATE_AVAILABLE.
    • Modified the wrapper function gdn_decode_klast_bf16_state_wrapper to gdn_decode_bf16_state_wrapper and added dispatch logic for T=1 and MTP (T>1) kernels.
    • Expanded seq_len support in benchmarks from fixed values (1,2,3,4) to any T>=1.
    • Adjusted benchmark output table headers and summary statistics to reflect the new kernel name.
  • flashinfer/gdn_kernels/__init__.py
    • Updated documentation comments to distinguish between T=1 and MTP (T>=1) BF16 state kernels.
    • Imported gated_delta_rule_mtp and added backward compatibility aliases for gated_delta_rule_bf16state_cooprow and gated_delta_rule_bf16state_cooprow_mtp.
    • Removed the export of GatedDeltaRuleKernel.
  • results_bf16_optimizations/cooprow_bf16_vs_optimized_fp32_mtp.md
    • Added a new markdown file detailing benchmark results comparing the cooperative row BF16 state kernel against an optimized FP32 MTP kernel.
  • tests/gdn/test_decode_delta_rule.py
    • Renamed kernel references and availability flags from gdn_decode_klast_bf16_state to gdn_decode_bf16_state.
    • Imported gdn_decode_bf16_state_mtp for multi-token prediction testing.
    • Updated the core test function _test_gdn_decode_klast_bf16_state_kernel to _test_gdn_decode_bf16_state_kernel, which now dispatches to the appropriate T=1 or MTP kernel based on seq_len.
    • Removed the seq_len constraint of T=1,2,3,4, allowing any T>=1.
    • Introduced new dedicated test functions _test_gdn_decode_bf16_state_t1_kernel and _test_gdn_decode_bf16_state_mtp_kernel for thorough validation of single-token and multi-token BF16 state kernels.
    • Updated the API dispatch test test_pretranspose_api_uses_gdn_decode_klast_bf16_state to test_pretranspose_api_uses_gdn_decode_bf16_state to verify correct kernel selection.
    • Adjusted smoke test output messages to reflect the new kernel naming.
Activity
  • The pull request introduces a new feature with significant performance improvements, as detailed in the benchmark results.
  • New benchmark results have been added in a dedicated markdown file to showcase the performance gains.
  • Extensive unit tests have been added and updated to cover the new BF16 state kernel for both single-token and multi-token prediction scenarios, ensuring correctness and reliability.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a high-performance BF16 state kernel for GDN decode, supporting both single-token (T=1) and multi-token prediction (MTP), which demonstrates significant performance improvements. The changes are well-structured, including updates to benchmarks and the addition of comprehensive tests.

My review has identified a bug in one of the new tests that needs to be addressed to ensure correctness. Additionally, I've provided a couple of suggestions to refactor duplicated code in both the benchmark and test files, which will improve the overall maintainability of the codebase.

Comment on lines +1862 to +1891
    if T == 1:
        return gdn_decode_bf16_state(
            A_log=A_log,
            a=a,
            dt_bias=dt_bias,
            softplus_beta=softplus_beta,
            softplus_threshold=softplus_threshold,
            q=q,
            k=k,
            v=v,
            b=b,
            initial_state_source=state,
            use_qk_l2norm_in_kernel=use_qk_l2norm,
            scale=scale,
        )
    else:
        return gdn_decode_bf16_state_mtp(
            A_log=A_log,
            a=a,
            dt_bias=dt_bias,
            softplus_beta=softplus_beta,
            softplus_threshold=softplus_threshold,
            q=q,
            k=k,
            v=v,
            b=b,
            initial_state_source=state,
            use_qk_l2norm_in_kernel=use_qk_l2norm,
            scale=scale,
        )

medium

The calls to gdn_decode_bf16_state and gdn_decode_bf16_state_mtp share the same set of arguments. This code can be refactored to reduce duplication and improve maintainability by selecting the kernel function first and then calling it with a shared set of keyword arguments.

    T = q.shape[1]
    kernel_fn = gdn_decode_bf16_state if T == 1 else gdn_decode_bf16_state_mtp
    return kernel_fn(
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        initial_state_source=state,
        use_qk_l2norm_in_kernel=use_qk_l2norm,
        scale=scale,
    )

Comment on lines +929 to +940
    def _test_gdn_decode_bf16_state_t1_kernel(
        dtype: str,
        batch_size: int,
        num_q_heads: int,
        num_k_heads: int,
        num_v_heads: int,
        head_size: int,
        scale: float,
        alpha: bool,
        beta: bool,
        seed: int | None = None,
    ):

medium

There's significant code duplication for test data generation across _test_gdn_decode_bf16_state_kernel, _test_gdn_decode_bf16_state_t1_kernel, and _test_gdn_decode_bf16_state_mtp_kernel. To improve maintainability, consider refactoring the tensor creation logic into a shared helper function or a pytest fixture. This would make the tests cleaner and easier to manage.

Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
benchmarks/bench_gdn_decode.py (3)

2788-2799: ⚠️ Potential issue | 🟠 Major

Restore --compare routing for decode versions in main().

For non-MTP paths, args.compare is currently ignored, so single-layout comparison mode is no longer reachable from the CLI.

Suggested fix
-    if args.version == "mtp":
+    if args.version == "mtp":
         # MTP mode: use comparison or flashinfer-only
         if args.compare:
             run_comparison_benchmark(args, dtype, use_qk_l2norm)
         else:
             run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
     elif args.version == "bf16_state":
         # BF16 state benchmark: T=1 and MTP T>=2 vs FP32 MTP
         run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm)
     else:
-        # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_bf16_state)
-        run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
+        # Decode mode: honor --compare flag
+        if args.compare:
+            run_comparison_benchmark(args, dtype, use_qk_l2norm)
+        else:
+            run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 2788 - 2799, The non-MTP
branches ignore args.compare so the CLI --compare mode is unreachable; update
the bf16_state and else branches in main(): for the "bf16_state" branch, if
args.compare call run_comparison_benchmark(args, dtype, use_qk_l2norm) else call
run_gdn_decode_bf16_state_benchmark(...); for the final else branch, if
args.compare call run_comparison_benchmark(...) else call
run_all_layouts_benchmark(...). Use the same argument list (args, dtype,
use_qk_l2norm) when invoking the functions run_comparison_benchmark,
run_gdn_decode_bf16_state_benchmark, and run_all_layouts_benchmark so --compare
behaves consistently across versions.

1845-1891: ⚠️ Potential issue | 🟠 Major

Forward preallocated tensors in BF16 MTP wrapper to avoid benchmark allocations.

The wrapper accepts output and ignores it in both T=1 and T>1 paths. For T>1, the MTP kernel supports both output and initial_state_indices parameters, but the wrapper doesn't forward them. This causes allocations during benchmark timing, skewing measurements.

Add initial_state_indices parameter to wrapper signature and forward both parameters to gdn_decode_bf16_state_mtp(). Update bench_gdn_decode_bf16_state() to create and pass initial_state_indices when T>1.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 1845 - 1891, The wrapper around
gdn_decode_bf16_state currently ignores the preallocated output and doesn't
accept or forward initial_state_indices for the multi-timestep path, causing
benchmark allocations; update the wrapper signature to add an
initial_state_indices: torch.Tensor (or optional) parameter, and when T>1
forward both output and initial_state_indices into gdn_decode_bf16_state_mtp by
passing the kernel's output=output and
initial_state_indices=initial_state_indices args (keep the T==1 call unchanged),
and update bench_gdn_decode_bf16_state() to allocate/create
initial_state_indices for the T>1 case and pass it through to the wrapper so no
new allocations occur inside the kernel call.

2049-2068: ⚠️ Potential issue | 🟠 Major

Use float32 dt_bias for BF16-state benchmark calls.

The gdn_decode_bf16_state kernel specification requires dt_bias as [HV] float32, as confirmed by kernel docstrings and test code comments. These benchmark paths currently create it with dtype (typically BF16), which diverges from the intended interface and may benchmark a different numerical path than intended.

Suggested fix
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")

Also applies to: 2256-2260

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 2049 - 2068, The benchmark is
passing dt_bias with the wrong dtype for the BF16-state kernel; update the
dt_bias tensor construction used in the gdn_decode_bf16_state benchmark calls so
dt_bias is created as float32 (torch.float32) with shape [num_sab_heads *
head_size] or [HV] as required, then pass that float32 dt_bias into
gdn_decode_bf16_state_wrapper (and the second BF16-state call around the other
occurrence). Locate the dt_bias variable used in the
gdn_decode_bf16_state_wrapper invocations and change its dtype to torch.float32
while keeping device and shape identical.
tests/gdn/test_decode_delta_rule.py (1)

824-831: ⚠️ Potential issue | 🟠 Major

Fix kernel dispatch for seq_len > 1 in pretranspose API test.

The test parametrizes seq_len with [1, 2, 3, 4] (line 824), but the direct kernel call at line 902 always invokes gdn_decode_bf16_state, which is the single-token (T=1) kernel. For seq_len > 1, the test should dispatch to gdn_decode_bf16_state_mtp and pass the initial_state_indices parameter. Currently this breaks verification of the multi-token path.

Suggested fix
-    # Direct improved kernel
-    out_direct = gdn_decode_bf16_state(
+    # Direct kernel: T=1 uses single-token path, T>1 uses MTP path
+    direct_kernel = (
+        gdn_decode_bf16_state if seq_len == 1 else gdn_decode_bf16_state_mtp
+    )
+    direct_kwargs = dict(
         A_log=A_log,
         a=a,
         dt_bias=dt_bias,
         softplus_beta=1.0,
         softplus_threshold=20.0,
         q=q,
         k=k,
         v=v,
         b=b_tensor,
         initial_state_source=state_direct,
         use_qk_l2norm_in_kernel=True,
         scale=scale,
-    )
+    )
+    if seq_len > 1:
+        direct_kwargs["initial_state_indices"] = torch.arange(
+            batch_size, dtype=torch.int32, device=device
+        )
+    out_direct = direct_kernel(**direct_kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 824 - 831, The test
test_pretranspose_api_uses_gdn_decode_bf16_state incorrectly always calls the
single-token kernel gdn_decode_bf16_state; update the dispatch so when seq_len >
1 it calls gdn_decode_bf16_state_mtp and supplies the initial_state_indices
argument (preserving the existing call for seq_len == 1). Locate the direct
kernel invocation around the test body (the call to gdn_decode_bf16_state) and
add a conditional: if seq_len == 1 keep the existing call, else call
gdn_decode_bf16_state_mtp with the same parameters plus initial_state_indices so
the multi-token verification path is exercised.
🧹 Nitpick comments (2)
tests/gdn/test_decode_delta_rule.py (2)

1148-1179: Validate cached intermediate states in the BF16 MTP test path.

When cache_intermediate_states=True, the test currently doesn’t verify buffer contents. Adding a reference comparison would protect the new intermediate-caching path from silent regressions.

Also applies to: 1218-1225

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 1148 - 1179, Add an
explicit verification of the cached intermediate states when
cache_intermediate_states=True: after calling gdn_decode_bf16_state_mtp with
intermediate_states_buffer provided, run a reference call that produces the
expected intermediate states (e.g., call gdn_decode_bf16_state_mtp or the
float32 equivalent with caching disabled/producing a reference buffer) and
compare intermediate_states_buffer to that reference using an appropriate
numeric tolerance and dtype/device conversion (use torch.testing.assert_allclose
or equivalent) so the BF16 MTP test path actually validates buffer contents;
apply the same check in the other test location that mirrors this logic (the
block around the second call referenced in the comment).

634-637: Remove or use the unused alpha helper parameter.

alpha is currently unused in the BF16 helper tests, which makes the test surface a bit misleading.

Also applies to: 937-940

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 634 - 637, The helper
function in tests/gdn/test_decode_delta_rule.py declares an unused parameter
alpha alongside beta and seed; either remove alpha from the helper signature and
all its call sites in that file (including the BF16 helper usages) or modify the
BF16 helper tests to actually use the alpha parameter, ensuring you update the
function signature and every invocation consistently; reference the parameter
names alpha, beta, seed when making the change so you catch all occurrences (the
same unused-alpha issue also appears later in the file around the BF16 helper
usages).

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e08e8f3 and d470d1d.

📒 Files selected for processing (6)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • results_bf16_optimizations/cooprow_bf16_vs_optimized_fp32_mtp.md
  • tests/gdn/test_decode_delta_rule.py

Add a high-performance CuTe DSL kernel for GDN decode with BF16 hidden
state storage. Provides both T=1 (single token) and MTP (multi-token
prediction) variants using a cooperative row approach.

Key design:
- Each warp processes one V-row at a time (4 warps = 4 V-rows/iter)
- cp.async pipeline with TILE_V=8 x TILE_K=128 tiles
- H state stored as BF16 in memory, FP32 in registers for compute
- ILP-optimized variant for large batch sizes (BS>=32)
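The BF16-in-memory / FP32-in-registers split above can be sketched in plain PyTorch (an illustrative sketch only — `update_state_bf16` and the [K, V] shapes are stand-ins, not the CuTe kernel's actual code):

```python
import torch

def update_state_bf16(h_bf16: torch.Tensor, delta_f32: torch.Tensor) -> torch.Tensor:
    # Widen the stored BF16 state to FP32 "registers", accumulate the
    # update at full precision, then narrow back to BF16 for storage.
    h_f32 = h_bf16.float()
    h_f32 = h_f32 + delta_f32
    return h_f32.to(torch.bfloat16)

h = torch.zeros(128, 128, dtype=torch.bfloat16)        # H state, K=V=128
delta = torch.randn(128, 128, dtype=torch.float32) * 1e-3
h = update_state_bf16(h, delta)
```

The point of the pattern is that rounding to BF16 happens once per update, not inside the accumulation itself.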

Consolidated from separate cooprow file into canonical
gdn_decode_bf16_state.py, replacing the old 32x128 H-chunk kernel.
Updated gdn_decode.py dispatch to use BF16 state kernel for both T=1
and MTP (T>1) when state is BF16 and K=V=128.
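As a rough sketch of that dispatch rule (a hypothetical helper — the real logic in gdn_decode.py also checks kernel availability and pool/index handling):

```python
def pick_gdn_decode_kernel(state_dtype: str, K: int, V: int, T: int) -> str:
    # BF16 state kernels only cover K=V=128; fall back to FP32 otherwise.
    if state_dtype == "bfloat16" and K == 128 and V == 128:
        return "gdn_decode_bf16_state" if T == 1 else "gdn_decode_bf16_state_mtp"
    return "gated_delta_rule_fp32"
```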

Benchmark results (B200, Qwen3-Next config, BF16 state MTP vs FP32 MTP):
- BS=1-2:   1.09-1.35x speedup
- BS=4-16:  1.24-2.21x speedup (biggest gains)
- BS=32-512: 1.62-1.81x steady-state speedup
- Peak: 13.8 TFLOPS (BS=512, T=8) vs FP32's 7.9 TFLOPS

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub ameynaik-hub force-pushed the ameyn/gdn_bf16_improvements branch from d470d1d to a906f21 Compare March 3, 2026 20:39
Contributor

@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/bench_gdn_decode.py (1)

1928-1931: ⚠️ Potential issue | 🟠 Major

Use float32 dt_bias in BF16-state benchmark paths.

BF16-state tests in this PR consistently feed dt_bias as float32, but these benchmark paths generate dt_bias with the generic `dtype` (typically bf16/fp16). That can trigger dtype assertions/casts and distort or fail BF16-state benchmarking.

Suggested patch
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")

Also applies to: 2257-2260

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 1928 - 1931, The benchmark
creates dt_bias with the generic dtype (bf16/fp16) causing dtype mismatches in
BF16-state paths; change the dt_bias creation to explicitly use torch.float32
(e.g., dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda"))
wherever dt_bias is constructed alongside A_log, a, b (the dt_bias variable in
the block with A_log, a, b) and apply the same fix to the other identical
occurrence later in the file.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Line 2420: The function gated_delta_rule_mtp now sets the parameter
disable_state_update default to False which silently changes behavior; update it
to preserve prior behavior by restoring disable_state_update: bool = True in the
gated_delta_rule_mtp signature (or if the change is intentional, update the
function's docstring and any callers to document and adopt
disable_state_update=False) and ensure the docstring text for
gated_delta_rule_mtp describing the default matches the new default to avoid
mismatch.

In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 623-635: The helper function _test_gdn_decode_bf16_state_kernel
declares an unused parameter alpha which triggers ARG001; rename alpha to _alpha
(or remove it) to mark it as intentionally unused and silence the linter, and
apply the same rename to the other BF16-state helper(s) in this file that also
declare an unused alpha parameter so all occurrences are fixed.

---

Outside diff comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 1928-1931: The benchmark creates dt_bias with the generic dtype
(bf16/fp16) causing dtype mismatches in BF16-state paths; change the dt_bias
creation to explicitly use torch.float32 (e.g., dt_bias =
torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")) wherever dt_bias
is constructed alongside A_log, a, b (the dt_bias variable in the block with
A_log, a, b) and apply the same fix to the other identical occurrence later in
the file.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d470d1d and a906f21.

📒 Files selected for processing (5)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • tests/gdn/test_decode_delta_rule.py

Comment on lines +623 to 635
 def _test_gdn_decode_bf16_state_kernel(
     dtype: str,
     batch_size: int,
     num_q_heads: int,
     num_k_heads: int,
     num_v_heads: int,
     head_size: int,
-    seq_len: int,  # T=1,2,3,4
+    seq_len: int,
     scale: float,
     alpha: bool,
     beta: bool,
     seed: int | None = None,
 ):
Contributor

⚠️ Potential issue | 🟡 Minor

Resolve unused alpha parameters in BF16-state helper tests.

alpha is not read in these helper bodies, and Ruff is already flagging it (ARG001). Rename to _alpha (or remove) to avoid lint noise/failures.

Suggested patch
 def _test_gdn_decode_bf16_state_kernel(
@@
-    alpha: bool,
+    _alpha: bool,
@@
 def _test_gdn_decode_bf16_state_t1_kernel(
@@
-    alpha: bool,
+    _alpha: bool,

Also applies to: 928-939

🧰 Tools
🪛 Ruff (0.15.2)

[warning] 632-632: Unused function argument: alpha

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 623 - 635, The helper
function _test_gdn_decode_bf16_state_kernel declares an unused parameter alpha
which triggers ARG001; rename alpha to _alpha (or remove it) to mark it as
intentionally unused and silence the linter, and apply the same rename to the
other BF16-state helper(s) in this file that also declare an unused alpha
parameter so all occurrences are fixed.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (2)
benchmarks/bench_gdn_decode.py (1)

1878-1891: ⚠️ Potential issue | 🟠 Major

Benchmark the MTP kernel with the preallocated output tensor.

This timed branch still omits output=output, so each iteration benchmarks extra allocation/copy work instead of just the kernel path.

Suggested patch
     else:
         return gdn_decode_bf16_state_mtp(
             A_log=A_log,
             a=a,
             dt_bias=dt_bias,
             softplus_beta=softplus_beta,
             softplus_threshold=softplus_threshold,
             q=q,
             k=k,
             v=v,
             b=b,
             initial_state_source=state,
             use_qk_l2norm_in_kernel=use_qk_l2norm,
             scale=scale,
+            output=output,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 1878 - 1891, The benchmark call
to gdn_decode_bf16_state_mtp is missing the preallocated output tensor, causing
allocations each iteration; update the call site where
gdn_decode_bf16_state_mtp(...) is invoked (the function call shown) to pass
output=output so the preallocated tensor is used, keeping all other named params
(A_log, a, dt_bias, softplus_beta, softplus_threshold, q, k, v, b,
initial_state_source=state, use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale)
unchanged.
flashinfer/gdn_decode.py (1)

509-546: ⚠️ Potential issue | 🟠 Major

Preserve the previous read-only default for disable_state_update.

Changing the default to False means existing callers that omit this argument now mutate initial_state. That is a silent public-API behavior change.

Suggested patch
-    disable_state_update: bool = False,
+    disable_state_update: bool = True,
@@
         disable_state_update (bool):
-            If True, the initial state is not updated. Default: ``False``.
+            If True, the initial state is not updated. Default: ``True``.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 509 - 546, The default for the
disable_state_update parameter was changed and must be restored to preserve
read-only behavior; in the function gdn_decode (the Gated Delta Rule MTP Kernel
that accepts disable_state_update: bool), revert the default value of
disable_state_update back to True, update any related docs/parameter description
in the function docstring to match (i.e., indicate default True and that
omitting the argument prevents mutation of initial_state), and ensure any
downstream callers/tests relying on the previous default continue to behave the
same.
🧹 Nitpick comments (1)
tests/gdn/test_decode_delta_rule.py (1)

1625-1702: Actually assert the cached BF16 intermediate states.

When cache_intermediate_states=True, this helper never checks intermediate_states_buffer. A broken beyond-T>4 cache-write path would still pass, which leaves the new feature effectively unverified.

Suggested test extension
     # Reference: step through tokens with bf16 state
     ref_state = input_state_ref_bf16.clone()
     ref_outputs = []
+    ref_intermediate_states = []

     for t in range(seq_len):
         ref_o_t, ref_state = decode_delta_rule(
             q[:, t].float(),
             k[:, t].float(),
@@
             use_l2_norm=True,
             state_dtype=torch.bfloat16,
         )
         ref_outputs.append(ref_o_t)
+        if cache_intermediate_states:
+            ref_intermediate_states.append(
+                ref_state.transpose(-2, -1).contiguous().clone()
+            )

     ref_o = torch.stack(ref_outputs, dim=1).to(dtype_torch)
@@
     torch.testing.assert_close(
         our_o.float(),
         ref_o.float(),
         atol=atol_o,
         rtol=rtol_o,
         msg=f"Output mismatch for MTP BF16 state kernel (B={batch_size}, T={seq_len})",
     )
+
+    if cache_intermediate_states and intermediate_states_buffer is not None:
+        ref_intermediate = torch.stack(ref_intermediate_states, dim=1)
+        torch.testing.assert_close(
+            intermediate_states_buffer.float(),
+            ref_intermediate.float(),
+            atol=0.02,
+            rtol=0.01,
+            msg=(
+                f"Intermediate-state cache mismatch for MTP BF16 state kernel "
+                f"(B={batch_size}, T={seq_len})"
+            ),
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 1625 - 1702, The test never
asserts intermediate_states_buffer when cache_intermediate_states=True, so add
assertions that the buffer returned/modified by gdn_decode_bf16_state_mtp
(intermediate_states_buffer) matches the per-step BF16 intermediate states
computed during the reference loop using decode_delta_rule (collect the
per-token intermediate state values while building ref_outputs), comparing
shapes/dtypes and using tight atol/rtol similar to output checks; reference
intermediate_states_buffer, gdn_decode_bf16_state_mtp, decode_delta_rule,
our_state and input_state_kernel to locate where to capture and compare the
cached states and ensure the cache-write path is actually verified.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 206-212: The BF16 fast-path selection (use_bf16_state) must be
guarded against padded/negative pool indices: extend the predicate that sets
use_bf16_state (which currently uses _GDN_DECODE_BF16_STATE_AVAILABLE,
state_dtype, K and V) to also verify that initial_state_indices either is None
or contains no negative entries (e.g. check initial_state_indices is not
provided OR torch.all(initial_state_indices >= 0) is true) before enabling the
BF16 backend, because the BF16 pooled-state backend does not implement
negative-index semantics and will mis-handle padding slots.
- Around line 235-250: The BF16 MTP fast path currently ignores the caller's
preallocated output and allocates a new tensor then copies back; modify the call
to _gated_delta_rule_bf16_state_mtp so it writes directly into the caller's
output buffer (pass the caller's output as an explicit output argument and
ensure shape/dtype/stride compatibility), remove the downstream copy-from-kernel
back into `output`, and preserve existing flags (use_pool,
initial_state/initial_state_indices, use_qk_l2norm, scale_val) when forwarding
to _gated_delta_rule_bf16_state_mtp so the kernel writes in-place into the
provided `output` buffer.

---

Duplicate comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 1878-1891: The benchmark call to gdn_decode_bf16_state_mtp is
missing the preallocated output tensor, causing allocations each iteration;
update the call site where gdn_decode_bf16_state_mtp(...) is invoked (the
function call shown) to pass output=output so the preallocated tensor is used,
keeping all other named params (A_log, a, dt_bias, softplus_beta,
softplus_threshold, q, k, v, b, initial_state_source=state,
use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale) unchanged.

In `@flashinfer/gdn_decode.py`:
- Around line 509-546: The default for the disable_state_update parameter was
changed and must be restored to preserve read-only behavior; in the function
gdn_decode (the Gated Delta Rule MTP Kernel that accepts disable_state_update:
bool), revert the default value of disable_state_update back to True, update any
related docs/parameter description in the function docstring to match (i.e.,
indicate default True and that omitting the argument prevents mutation of
initial_state), and ensure any downstream callers/tests relying on the previous
default continue to behave the same.

---

Nitpick comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 1625-1702: The test never asserts intermediate_states_buffer when
cache_intermediate_states=True, so add assertions that the buffer
returned/modified by gdn_decode_bf16_state_mtp (intermediate_states_buffer)
matches the per-step BF16 intermediate states computed during the reference loop
using decode_delta_rule (collect the per-token intermediate state values while
building ref_outputs), comparing shapes/dtypes and using tight atol/rtol similar
to output checks; reference intermediate_states_buffer,
gdn_decode_bf16_state_mtp, decode_delta_rule, our_state and input_state_kernel
to locate where to capture and compare the cached states and ensure the
cache-write path is actually verified.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cb21f102-081e-40a1-854c-620a96b8d9d3

📥 Commits

Reviewing files that changed from the base of the PR and between a906f21 and 2ba157b.

📒 Files selected for processing (4)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/__init__.py
  • tests/gdn/test_decode_delta_rule.py

Comment on lines +206 to +212
     use_bf16_state = (
         _GDN_DECODE_BF16_STATE_AVAILABLE
         and state_dtype == torch.bfloat16
         and T in (1, 2, 3, 4)
         and K == 128
         and V == 128
     )
-    if use_gdn_decode_klast_bf16_state:
+    if use_bf16_state:
Contributor

⚠️ Potential issue | 🟠 Major

Don’t route negative pool indices into the BF16 fast path.

This predicate still selects the BF16 backend for pooled bf16 state even when initial_state_indices contains padding slots (-1). That backend does not implement the negative-index semantics, so these calls can read/write the wrong pool row instead of honoring padding.

Suggested guard
     use_bf16_state = (
         _GDN_DECODE_BF16_STATE_AVAILABLE
         and state_dtype == torch.bfloat16
         and K == 128
         and V == 128
     )
+    if use_bf16_state and use_pool and (initial_state_indices < 0).any().item():
+        raise ValueError(
+            "Negative initial_state_indices are only supported with float32 state; "
+            "the BF16 fast path does not support padding slots."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 206 - 212, The BF16 fast-path
selection (use_bf16_state) must be guarded against padded/negative pool indices:
extend the predicate that sets use_bf16_state (which currently uses
_GDN_DECODE_BF16_STATE_AVAILABLE, state_dtype, K and V) to also verify that
initial_state_indices either is None or contains no negative entries (e.g. check
initial_state_indices is not provided OR torch.all(initial_state_indices >= 0)
is true) before enabling the BF16 backend, because the BF16 pooled-state backend
does not implement negative-index semantics and will mis-handle padding slots.

Comment on lines +235 to +250
         # MTP kernel supports T>=1 and pool+indices
         out = _gated_delta_rule_bf16_state_mtp(
             A_log=A_log,
             a=a,
             dt_bias=dt_bias,
             softplus_beta=1.0,
             softplus_threshold=20.0,
             q=q,
             k=k,
             v=v,
             b=b,
             initial_state_source=initial_state if use_pool else state,
             initial_state_indices=initial_state_indices,
             use_qk_l2norm_in_kernel=use_qk_l2norm,
             scale=scale_val,
         )
Contributor

⚠️ Potential issue | 🟠 Major

Pass the caller’s output buffer into the BF16 MTP backend.

The new T>1 / pool fast path always allocates a fresh output tensor and then copies it back into output. That defeats preallocation and adds avoidable device traffic in the hot decode loop.

Suggested patch
         else:
             # MTP kernel supports T>=1 and pool+indices
             out = _gated_delta_rule_bf16_state_mtp(
                 A_log=A_log,
                 a=a,
                 dt_bias=dt_bias,
                 softplus_beta=1.0,
                 softplus_threshold=20.0,
                 q=q,
                 k=k,
                 v=v,
                 b=b,
                 initial_state_source=initial_state if use_pool else state,
                 initial_state_indices=initial_state_indices,
                 use_qk_l2norm_in_kernel=use_qk_l2norm,
                 scale=scale_val,
+                output=output,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 235 - 250, The BF16 MTP fast path
currently ignores the caller's preallocated output and allocates a new tensor
then copies back; modify the call to _gated_delta_rule_bf16_state_mtp so it
writes directly into the caller's output buffer (pass the caller's output as an
explicit output argument and ensure shape/dtype/stride compatibility), remove
the downstream copy-from-kernel back into `output`, and preserve existing flags
(use_pool, initial_state/initial_state_indices, use_qk_l2norm, scale_val) when
forwarding to _gated_delta_rule_bf16_state_mtp so the kernel writes in-place
into the provided `output` buffer.

Resolve 4 conflicting files after main's major refactor of gdn_decode.py
from a 2643-line monolith into a 645-line API layer with kernel code
extracted into flashinfer/gdn_kernels/ submodules.

Conflict resolutions:
- gdn_decode.py: Accept main's refactored API layer, port feature branch's
  BF16 state import renames, updated dispatch logic (T=1 no-pool → bf16_state,
  else → bf16_state_mtp), and disable_state_update default docstring fix.
- gdn_kernels/__init__.py: Merge both - keep main's expanded exports and add
  feature branch's MTP exports and backward-compat aliases.
- gdn_kernels/gdn_decode_bf16_state.py: Accept feature branch's complete
  rewrite with coop-row kernel approach + MTP kernel.
- tests/gdn/test_decode_delta_rule.py: Start from main's version (pool+indices
  tests, negative indices tests), apply feature branch's renames, dispatch
  split, new T=1 and MTP test functions, remove CI skip marker.

AI-assisted merge resolution.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub ameynaik-hub force-pushed the ameyn/gdn_bf16_improvements branch from 46306ff to 421334a Compare March 18, 2026 19:54
@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@ameynaik-hub is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/bench_gdn_decode.py (1)

2281-2306: ⚠️ Potential issue | 🟠 Major

Don’t charge cache traffic in the no-cache BF16 benchmark.

gdn_decode_bytes() adds [B, T, HV, V, K] intermediate-state writes for every seq_len > 1, but this benchmark never allocates or forwards intermediate_states_buffer. The reported TB/s is therefore inflated for every BF16 MTP run.
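One way to keep the TB/s model honest is to gate that term on whether the buffer is actually forwarded — a hedged sketch (the helper name and signature are illustrative, not the benchmark's actual gdn_decode_bytes()):

```python
def intermediate_state_bytes(B: int, T: int, HV: int, V: int, K: int,
                             state_bytes: int = 2, cached: bool = False) -> int:
    # Charge the [B, T, HV, V, K] per-token state writes only when the
    # benchmark really allocates and forwards intermediate_states_buffer.
    return B * T * HV * V * K * state_bytes if cached else 0
```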

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 2281 - 2306, The benchmark is
incorrectly charging intermediate-state write traffic for BF16 no-cache runs;
update the gdn_decode_bytes call in the benchmark so it does not count those
writes (e.g., set disable_state_update=True or set state_dtype_bytes=0) when the
intermediate_states_buffer is not allocated/forwarded; change the call that
currently uses disable_state_update=False and state_dtype_bytes=2 to instead
reflect no state updates for gdn_decode_bf16_state (e.g.,
disable_state_update=True or state_dtype_bytes=0) so TB/s is not inflated.
♻️ Duplicate comments (4)
flashinfer/gdn_decode.py (3)

509-510: ⚠️ Potential issue | 🟠 Major

Don’t silently flip gated_delta_rule_mtp() to mutating by default.

disable_state_update=False still changes the public default from verify-mode semantics to in-place state mutation for callers that omit the flag. The docstring now matches, but the API break remains.

Suggested patch
-    disable_state_update: bool = False,
+    disable_state_update: bool = True,
@@
-            If True, the initial state is not updated. Default: ``False``.
+            If True, the initial state is not updated. Default: ``True``.

Also applies to: 545-546

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 509 - 510, The function signature(s)
for gated_delta_rule_mtp (and the other occurrence at lines ~545-546) changed
the default to disable_state_update=False which silently makes the API mutating
by default; revert the public API to preserve verify-mode by setting
disable_state_update=True in the function signature(s) (or otherwise ensure the
default remains non-mutating), and keep the docstring consistent with that
non-mutating default so callers that omit the flag keep previous semantics.

236-250: ⚠️ Potential issue | 🟠 Major

Forward the caller’s output buffer into the BF16 MTP backend.

This path still lets _gated_delta_rule_bf16_state_mtp() allocate a fresh output and then copies it back into output, so preallocation does not actually remove the allocation/device copy on the hot path.

Suggested patch
             out = _gated_delta_rule_bf16_state_mtp(
                 A_log=A_log,
                 a=a,
                 dt_bias=dt_bias,
                 softplus_beta=1.0,
                 softplus_threshold=20.0,
                 q=q,
                 k=k,
                 v=v,
                 b=b,
                 initial_state_source=initial_state if use_pool else state,
                 initial_state_indices=initial_state_indices,
                 use_qk_l2norm_in_kernel=use_qk_l2norm,
                 scale=scale_val,
+                output=output,
             )
         output_provided = output is not None
         target_dtype = output.dtype if output_provided else q.dtype
-        if output is not None:
+        if output is not None and out is not output:
             output.copy_(out)
         else:
             output = out

Also applies to: 251-255

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 236 - 250, The call site currently
lets _gated_delta_rule_bf16_state_mtp allocate and return a new tensor which is
later copied into output; instead, forward the caller's preallocated output
buffer into the BF16 MTP backend by passing the existing output tensor as an
explicit argument (e.g., add an output=output or out=output parameter) when
calling _gated_delta_rule_bf16_state_mtp (and update the second duplicate call
at lines 251-255 similarly); ensure the backend function signature (
_gated_delta_rule_bf16_state_mtp ) accepts this output/out parameter and writes
into it in-place so no new allocation or device copy occurs on the hot path.

206-212: ⚠️ Potential issue | 🟠 Major

Guard the BF16 fast path against padded pool indices.

initial_state_indices < 0 still enables the BF16 backend here, but the BF16 pooled path does not implement padding-slot semantics. A bf16 pooled call with -1 can therefore read or update the wrong pool row instead of producing padded output.

Suggested guard
     use_bf16_state = (
         _GDN_DECODE_BF16_STATE_AVAILABLE
         and state_dtype == torch.bfloat16
         and K == 128
         and V == 128
     )
+    if use_bf16_state and use_pool and (initial_state_indices < 0).any().item():
+        raise ValueError(
+            "Negative initial_state_indices are only supported with float32 state; "
+            "the BF16 fast path does not support padding slots."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 206 - 212, The BF16 fast path
(computed into use_bf16_state) must be disabled when any pooled indices are
padding; update the condition that sets use_bf16_state (which currently checks
_GDN_DECODE_BF16_STATE_AVAILABLE, state_dtype, K == 128, V == 128) to also
verify that initial_state_indices contains no negative values (e.g., ensure
initial_state_indices.min() >= 0 or torch.all(initial_state_indices >= 0))
before enabling the BF16 path so the bf16 pooled implementation is never used
when padding slots (negative indices) are present.
benchmarks/bench_gdn_decode.py (1)

1878-1891: ⚠️ Potential issue | 🟠 Major

Use the preallocated output buffer in the BF16-state MTP benchmark path.

The wrapper still ignores its output argument for T > 1, so every timed iteration includes an avoidable allocation/copy and the BF16-state numbers are not directly comparable to the other kernels.

Suggested patch
         return gdn_decode_bf16_state_mtp(
             A_log=A_log,
             a=a,
             dt_bias=dt_bias,
             softplus_beta=softplus_beta,
             softplus_threshold=softplus_threshold,
             q=q,
             k=k,
             v=v,
             b=b,
             initial_state_source=state,
             use_qk_l2norm_in_kernel=use_qk_l2norm,
             scale=scale,
+            output=output,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 1878 - 1891, The wrapper for
BF16-state MTP (gdn_decode_bf16_state_mtp) currently ignores the provided output
buffer for T>1 and allocates/copies inside each iteration; modify the wrapper so
it writes into the passed-in output buffer instead of creating a new one for
multi-step runs: detect the T>1 path in the wrapper that calls
gdn_decode_bf16_state_mtp (the call shown with
A_log,a,dt_bias,...,initial_state_source=state,use_qk_l2norm_in_kernel,scale)
and forward the provided output argument into the underlying implementation or
reuse it for accumulating results, ensuring no per-iteration allocation/copy
happens and the shape/dtype semantics match the existing single-step case.
🧹 Nitpick comments (1)
tests/gdn/test_decode_delta_rule.py (1)

1625-1653: Actually assert the BF16 MTP intermediate-state cache.

When cache_intermediate_states=True, this helper allocates and passes intermediate_states_buffer, but never compares its contents with the step-by-step reference. A regression in the new caching path would still pass as long as the outputs stay correct.

Also applies to: 1660-1680

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 1625 - 1653, The test
allocates intermediate_states_buffer when cache_intermediate_states is True but
never asserts its contents; update the test after calling
gdn_decode_bf16_state_mtp (the call that sets intermediate_states_buffer) to
compare intermediate_states_buffer against the step-by-step reference buffer
used for the MTP/disable_state_update verification (the same reference used for
outputs), e.g., fetch the expected intermediate states produced by the
stepwise/reference decoder and assert equality (or close with appropriate dtype
tolerance) against intermediate_states_buffer; ensure you handle None when
cache_intermediate_states is False and reuse variables like
intermediate_states_buffer, our_state, and the reference stepwise buffer to
locate the correct data to compare.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 2802-2807: The fallback branch currently always calls
run_all_layouts_benchmark(), ignoring args.compare; adjust the else logic so
that when args.compare is true and args.version is "pretranspose" or
"nontranspose" you call run_comparison_benchmark(args, dtype, use_qk_l2norm)
instead of run_all_layouts_benchmark(), otherwise keep calling
run_all_layouts_benchmark(); keep the existing bf16_state branch that calls
run_gdn_decode_bf16_state_benchmark intact.

---

Outside diff comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 2281-2306: The benchmark is incorrectly charging
intermediate-state write traffic for BF16 no-cache runs; update the
gdn_decode_bytes call in the benchmark so it does not count those writes (e.g.,
set disable_state_update=True or set state_dtype_bytes=0) when the
intermediate_states_buffer is not allocated/forwarded; change the call that
currently uses disable_state_update=False and state_dtype_bytes=2 to instead
reflect no state updates for gdn_decode_bf16_state (e.g.,
disable_state_update=True or state_dtype_bytes=0) so TB/s is not inflated.
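The gating the comment asks for can be sketched as follows. This is a simplified, hypothetical byte-accounting helper (the real gdn_decode_bytes has more parameters); the point is only that the [B, T, HV, K, V] intermediate-state traffic is charged under the same condition used to allocate the buffer:

```python
def gdn_decode_bytes(B, T, HV, K, V, in_dtype_bytes=2, state_dtype_bytes=4,
                     cache_intermediate_states=False):
    """Hypothetical byte accounting: q/k/v + output traffic is always
    counted (roughly), but intermediate-state writes [B, T, HV, K, V]
    are counted only when the cache buffer is actually allocated
    (cache_intermediate_states and T > 1)."""
    io_bytes = B * T * HV * (K + 2 * V) * in_dtype_bytes  # rough q/k/v + out
    state_bytes = 0
    if cache_intermediate_states and T > 1:
        state_bytes = B * T * HV * K * V * state_dtype_bytes
    return io_bytes + state_bytes

# No-cache runs are not charged state traffic, so TB/s is not inflated:
base = gdn_decode_bytes(4, 2, 32, 128, 128)
cached = gdn_decode_bytes(4, 2, 32, 128, 128, cache_intermediate_states=True)
assert cached > base
```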

---

Duplicate comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 1878-1891: The wrapper for BF16-state MTP
(gdn_decode_bf16_state_mtp) currently ignores the provided output buffer for T>1
and allocates/copies inside each iteration; modify the wrapper so it writes into
the passed-in output buffer instead of creating a new one for multi-step runs:
detect the T>1 path in the wrapper that calls gdn_decode_bf16_state_mtp (the
call shown with
A_log,a,dt_bias,...,initial_state_source=state,use_qk_l2norm_in_kernel,scale)
and forward the provided output argument into the underlying implementation or
reuse it for accumulating results, ensuring no per-iteration allocation/copy
happens and the shape/dtype semantics match the existing single-step case.

In `@flashinfer/gdn_decode.py`:
- Around line 509-510: The function signature(s) for gated_delta_rule_mtp (and
the other occurrence at lines ~545-546) changed the default to
disable_state_update=False which silently makes the API mutating by default;
revert the public API to preserve verify-mode by setting
disable_state_update=True in the function signature(s) (or otherwise ensure the
default remains non-mutating), and keep the docstring consistent with that
non-mutating default so callers that omit the flag keep previous semantics.
- Around line 236-250: The call site currently lets
_gated_delta_rule_bf16_state_mtp allocate and return a new tensor which is later
copied into output; instead, forward the caller's preallocated output buffer
into the BF16 MTP backend by passing the existing output tensor as an explicit
argument (e.g., add an output=output or out=output parameter) when calling
_gated_delta_rule_bf16_state_mtp (and update the second duplicate call at lines
251-255 similarly); ensure the backend function signature (
_gated_delta_rule_bf16_state_mtp ) accepts this output/out parameter and writes
into it in-place so no new allocation or device copy occurs on the hot path.
- Around line 206-212: The BF16 fast path (computed into use_bf16_state) must be
disabled when any pooled indices are padding; update the condition that sets
use_bf16_state (which currently checks _GDN_DECODE_BF16_STATE_AVAILABLE,
state_dtype, K == 128, V == 128) to also verify that initial_state_indices
contains no negative values (e.g., ensure initial_state_indices.min() >= 0 or
torch.all(initial_state_indices >= 0)) before enabling the BF16 path so the bf16
pooled implementation is never used when padding slots (negative indices) are
present.
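The padding guard from the last item can be sketched like this. initial_state_indices is a plain sequence here for illustration; in gdn_decode.py it is a torch tensor, checked with torch.all(initial_state_indices >= 0):

```python
def can_use_bf16_state(kernel_available, state_dtype, K, V, initial_state_indices):
    """Sketch of the dispatch guard suggested in the review: the BF16
    pooled path is enabled only when the kernel is available, the state
    dtype and head sizes match, and no pooled index is a padding slot
    (negative index)."""
    return (
        kernel_available
        and state_dtype == "bfloat16"
        and K == 128
        and V == 128
        and all(i >= 0 for i in initial_state_indices)
    )

assert can_use_bf16_state(True, "bfloat16", 128, 128, [0, 3, 7])
assert not can_use_bf16_state(True, "bfloat16", 128, 128, [0, -1, 7])  # padding slot
```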

---


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f4c83a1b-0f97-4567-9cdd-ef2a909543be

📥 Commits

Reviewing files that changed from the base of the PR and between 2ba157b and 421334a.

📒 Files selected for processing (5)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • tests/gdn/test_decode_delta_rule.py

Comment on lines +2802 to 2807
elif args.version == "bf16_state":
# BF16 state benchmark: T=1 and MTP T>=2 vs FP32 MTP
run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm)
else:
# Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state)
# Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_bf16_state)
run_all_layouts_benchmark(args, dtype, use_qk_l2norm)

⚠️ Potential issue | 🟠 Major

--compare is ignored for decode benchmarks.

After adding the bf16_state branch, the fallback path always calls run_all_layouts_benchmark(). --compare --version pretranspose or --compare --version nontranspose therefore never reaches run_comparison_benchmark(), despite the help text and examples promising that behavior.

Suggested patch
-    if args.version == "mtp":
-        # MTP mode: use comparison or flashinfer-only
-        if args.compare:
-            run_comparison_benchmark(args, dtype, use_qk_l2norm)
-        else:
-            run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
-    elif args.version == "bf16_state":
+    if args.version == "bf16_state":
         # BF16 state benchmark: T=1 and MTP T>=2 vs FP32 MTP
         run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm)
+    elif args.compare:
+        run_comparison_benchmark(args, dtype, use_qk_l2norm)
+    elif args.version == "mtp":
+        run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
     else:
         # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_bf16_state)
         run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
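The dispatch the suggested patch arrives at can be summarized in a small selector. This is a sketch, not the benchmark's actual control flow; the runner names are taken from the script:

```python
def select_benchmark(version, compare):
    """Sketch of the corrected dispatch: bf16_state gets its dedicated
    runner, --compare is honored for every other version (including
    pretranspose/nontranspose), and mtp without --compare runs the
    flashinfer-only benchmark."""
    if version == "bf16_state":
        return "run_gdn_decode_bf16_state_benchmark"
    if compare:
        return "run_comparison_benchmark"
    if version == "mtp":
        return "run_flashinfer_only_benchmark"
    return "run_all_layouts_benchmark"

# --compare --version pretranspose now reaches the comparison runner
# instead of falling through to the all-layouts benchmark:
assert select_benchmark("pretranspose", True) == "run_comparison_benchmark"
```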

Forward intermediate_states_buffer, disable_state_update, and
initial_state_indices to the MTP kernel. Pre-allocate buffers
outside the timed lambda to avoid per-call overhead in CUPTI
measurements.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
benchmarks/bench_gdn_decode.py (1)

2839-2850: ⚠️ Potential issue | 🟠 Major

--compare is still ignored for decode benchmarks.

For --version pretranspose and --version nontranspose, this branch still falls through to run_all_layouts_benchmark(), so the advertised single-layout FlashInfer-vs-Triton mode never runs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 2839 - 2850, The branch ignores
--compare for single-layout decode modes because versions "pretranspose" and
"nontranspose" fall through to run_all_layouts_benchmark(); change control flow
to handle these versions explicitly: add an elif (or cases) for args.version ==
"pretranspose" or "nontranspose" that mirrors the "mtp" handling by calling
run_comparison_benchmark(args, dtype, use_qk_l2norm) when args.compare is true,
otherwise call run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm),
leaving run_all_layouts_benchmark and run_gdn_decode_bf16_state_benchmark
unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 2281-2292: The bandwidth calculation double-counts
intermediate-state traffic even when intermediate_states_buffer is not
allocated; update gdn_decode_bytes() usage (or the bandwidth accounting around
it) to only include the [B, T, HV, K, V] intermediate-state bytes when
cache_intermediate_states is true and intermediate_states_buffer is allocated
(i.e., same condition used to create intermediate_states_buffer:
cache_intermediate_states and T > 1). Locate references to
intermediate_states_buffer, cache_intermediate_states, and gdn_decode_bytes()
(also the duplicate accounting block around the second occurrence noted) and
wrap or gate the intermediate-state contribution so TB/s is only incremented
when caching is enabled.
- Around line 1836-1838: The wrapper currently passes k (documented as
[B,T,H,K]) unchanged into gated_delta_rule_mtp while deriving H from q, which
breaks when num_q_heads != num_k_heads for the bf16_state MTP path; update the
benchmark wrapper to detect when bf16_state is chosen and either (a)
expand/reshape k from num_k_heads to num_q_heads by repeating or broadcasting
the key-head dimension so k aligns with q (match q.shape[2]) before calling
gated_delta_rule_mtp(q, k, v, ...), or (b) raise a clear error rejecting
mismatched num_k_heads vs num_q_heads for bf16_state; reference
variables/functions: gated_delta_rule_mtp, q, k, v, bf16_state, num_k_heads,
num_q_heads.
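Option (a) from the last item boils down to computing a repeat factor for the key heads. A sketch, with the torch call it would feed left as a comment (in the benchmark this would be k.repeat_interleave(factor, dim=2)):

```python
def key_head_repeat_factor(num_q_heads, num_k_heads):
    """Sketch: how many times each key head must be repeated so k matches
    q's head dimension before the bf16_state MTP call. Rejects head
    counts that cannot be aligned, per option (b) in the review."""
    if num_q_heads % num_k_heads != 0:
        raise ValueError(
            f"num_q_heads ({num_q_heads}) must be a multiple of "
            f"num_k_heads ({num_k_heads}) for the bf16_state path"
        )
    return num_q_heads // num_k_heads

# Qwen3-Next-style config (16 q-heads, 16 k-heads) needs no expansion;
# a 32/8 GQA split would repeat each key head 4 times.
assert key_head_repeat_factor(16, 16) == 1
assert key_head_repeat_factor(32, 8) == 4
```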

---


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 18cc0a05-1daa-4fed-aa45-f72b4113cf36

📥 Commits

Reviewing files that changed from the base of the PR and between 421334a and 4ac0e09.

📒 Files selected for processing (1)
  • benchmarks/bench_gdn_decode.py

Comment on lines +1836 to 1838
q: torch.Tensor, # [B, T, H_Q, K]
k: torch.Tensor, # [B, T, H_K, K]
v: torch.Tensor, # [B, T, HV, V]

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n--- low-level BF16-state MTP contract ---\n'
sed -n '2173,2223p' flashinfer/gdn_kernels/gdn_decode_bf16_state.py

printf '\n--- benchmark wrapper ---\n'
sed -n '1835,1898p' benchmarks/bench_gdn_decode.py

printf '\n--- benchmark tensor allocation ---\n'
sed -n '2259,2320p' benchmarks/bench_gdn_decode.py

Repository: flashinfer-ai/flashinfer

Length of output: 6298


Expand k to query heads before calling the low-level BF16-state MTP kernel.

gated_delta_rule_mtp() documents its k parameter as [B, T, H, K] and derives H from q.shape, expecting keys to match the query head dimension. The benchmark wrapper forwards k directly without expanding it from num_k_heads to num_q_heads, causing GQA/MQA cases where num_q_heads != num_k_heads to either mis-benchmark or fail on the BF16-state path. Either materialize query-head-aligned keys in the wrapper or reject mismatched head counts for bf16_state.


@bkryu bkryu added the run-ci label Mar 19, 2026
Resolve conflict in flashinfer/gdn_decode.py by accepting main's
disable_state_update deprecation warning from PR flashinfer-ai#2730.

AI-assisted: Claude Code

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub ameynaik-hub requested a review from yyihuang as a code owner March 22, 2026 04:04
@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@ameynaik-hub is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

Resolve conflicts with ae9a64d (feat(gdn): add padding):
- kernel: take our rewritten T=1 kernels (old kernels deleted);
  MTP kernel already has `if cache_idx >= 0:` guard
- test: trivial blank-line conflict; new padding test auto-merged

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@kahyunnam
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !450 has been created, and the CI pipeline #46822953 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46822953: 6/20 passed

…ode kernel

Fix 5 CI test failures in test_decode_delta_rule.py after merging main's
padding commit (ae9a64d) into ameyn/gdn_bf16_improvements:

1. MTP kernel padding fix: redirect cache_idx < 0 to slot 0 instead of
   skipping, matching the test expectation that padding slots write to a
   null buffer rather than producing uninitialized output.

2. Small-batch routing fix: route B < ILP_BATCH_THRESHOLD through the
   MTP kernel's T=1 path instead of the cooprow kernel, which has known
   correctness issues at small batch sizes (e.g. B=1, B=2). Remove the
   cooprow dispatch code entirely.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@kahyunnam
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !450 has been updated with latest changes, and the CI pipeline #46893914 is currently running. I'll report back once the pipeline job completes.

