feat(gdn): add BF16 state kernel with MTP support beyond T>4 with intermediate caching. #2679

ameynaik-hub wants to merge 7 commits into flashinfer-ai:main
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

This PR splits and renames the BF16-state GDN decode backend into two variants (single-token BF16 state and BF16-state MTP for multi-token), updates kernel exports/availability flags, changes runtime dispatch to choose T==1 vs T>1 kernels, and adjusts benchmarks/tests/CLI to match the new APIs.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Decode API
    participant Dispatch as Kernel Dispatch
    participant BF16 as BF16 State (T=1)
    participant MTP as BF16 State MTP (T>1)
    Client->>Dispatch: decode(T, dtype=bf16, args...)
    activate Dispatch
    alt T == 1
        Dispatch->>BF16: call _gated_delta_rule_bf16_state(...)
        activate BF16
        BF16-->>Dispatch: result
        deactivate BF16
    else T > 1
        Dispatch->>MTP: call _gated_delta_rule_bf16_state_mtp(...)
        activate MTP
        MTP-->>Dispatch: result
        deactivate MTP
    end
    deactivate Dispatch
    Dispatch-->>Client: decoded output
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the GDN decode functionality by integrating a new, highly optimized BF16 state kernel. This kernel, built with the CuTe DSL, provides substantial performance gains across various batch sizes and sequence lengths, particularly for multi-token prediction. The changes involve refactoring existing kernel implementations, updating benchmarking infrastructure, and expanding test coverage to ensure correctness and efficiency of the new BF16 state processing.
Activity
Code Review
This pull request introduces a high-performance BF16 state kernel for GDN decode, supporting both single-token (T=1) and multi-token prediction (MTP), which demonstrates significant performance improvements. The changes are well-structured, including updates to benchmarks and the addition of comprehensive tests.
My review has identified a bug in one of the new tests that needs to be addressed to ensure correctness. Additionally, I've provided a couple of suggestions to refactor duplicated code in both the benchmark and test files, which will improve the overall maintainability of the codebase.
```python
if T == 1:
    return gdn_decode_bf16_state(
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        initial_state_source=state,
        use_qk_l2norm_in_kernel=use_qk_l2norm,
        scale=scale,
    )
else:
    return gdn_decode_bf16_state_mtp(
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        initial_state_source=state,
        use_qk_l2norm_in_kernel=use_qk_l2norm,
        scale=scale,
    )
```
The calls to `gdn_decode_bf16_state` and `gdn_decode_bf16_state_mtp` share the same set of arguments. This code can be refactored to reduce duplication and improve maintainability by selecting the kernel function first and then calling it with a shared set of keyword arguments.
```python
T = q.shape[1]
kernel_fn = gdn_decode_bf16_state if T == 1 else gdn_decode_bf16_state_mtp
return kernel_fn(
    A_log=A_log,
    a=a,
    dt_bias=dt_bias,
    softplus_beta=softplus_beta,
    softplus_threshold=softplus_threshold,
    q=q,
    k=k,
    v=v,
    b=b,
    initial_state_source=state,
    use_qk_l2norm_in_kernel=use_qk_l2norm,
    scale=scale,
)
```

```python
def _test_gdn_decode_bf16_state_t1_kernel(
    dtype: str,
    batch_size: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    scale: float,
    alpha: bool,
    beta: bool,
    seed: int | None = None,
):
```
There's significant code duplication for test data generation across `_test_gdn_decode_bf16_state_kernel`, `_test_gdn_decode_bf16_state_t1_kernel`, and `_test_gdn_decode_bf16_state_mtp_kernel`. To improve maintainability, consider refactoring the tensor creation logic into a shared helper function or a pytest fixture. This would make the tests cleaner and easier to manage.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
benchmarks/bench_gdn_decode.py (3)
2788-2799: ⚠️ Potential issue | 🟠 Major

Restore `--compare` routing for decode versions in `main()`.

For non-MTP paths, `args.compare` is currently ignored, so single-layout comparison mode is no longer reachable from the CLI.

Suggested fix

```diff
 if args.version == "mtp":
     # MTP mode: use comparison or flashinfer-only
     if args.compare:
         run_comparison_benchmark(args, dtype, use_qk_l2norm)
     else:
         run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
 elif args.version == "bf16_state":
     # BF16 state benchmark: T=1 and MTP T>=2 vs FP32 MTP
     run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm)
 else:
-    # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_bf16_state)
-    run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
+    # Decode mode: honor --compare flag
+    if args.compare:
+        run_comparison_benchmark(args, dtype, use_qk_l2norm)
+    else:
+        run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 2788 - 2799, The non-MTP branches ignore args.compare so the CLI --compare mode is unreachable; update the bf16_state and else branches in main(): for the "bf16_state" branch, if args.compare call run_comparison_benchmark(args, dtype, use_qk_l2norm) else call run_gdn_decode_bf16_state_benchmark(...); for the final else branch, if args.compare call run_comparison_benchmark(...) else call run_all_layouts_benchmark(...). Use the same argument list (args, dtype, use_qk_l2norm) when invoking the functions run_comparison_benchmark, run_gdn_decode_bf16_state_benchmark, and run_all_layouts_benchmark so --compare behaves consistently across versions.
1845-1891: ⚠️ Potential issue | 🟠 Major

Forward preallocated tensors in the BF16 MTP wrapper to avoid benchmark allocations.

The wrapper accepts `output` and ignores it in both the T=1 and T>1 paths. For T>1, the MTP kernel supports both `output` and `initial_state_indices` parameters, but the wrapper doesn't forward them. This causes allocations during benchmark timing, skewing measurements.

Add an `initial_state_indices` parameter to the wrapper signature and forward both parameters to `gdn_decode_bf16_state_mtp()`. Update `bench_gdn_decode_bf16_state()` to create and pass `initial_state_indices` when T>1.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 1845 - 1891, The wrapper around gdn_decode_bf16_state currently ignores the preallocated output and doesn't accept or forward initial_state_indices for the multi-timestep path, causing benchmark allocations; update the wrapper signature to add an initial_state_indices: torch.Tensor (or optional) parameter, and when T>1 forward both output and initial_state_indices into gdn_decode_bf16_state_mtp by passing the kernel's output=output and initial_state_indices=initial_state_indices args (keep the T==1 call unchanged), and update bench_gdn_decode_bf16_state() to allocate/create initial_state_indices for the T>1 case and pass it through to the wrapper so no new allocations occur inside the kernel call.
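Stripped down to the forwarding logic, the suggested wrapper change looks like this sketch; `kernel_t1` and `kernel_mtp` are pure-Python stand-ins for `gdn_decode_bf16_state` / `gdn_decode_bf16_state_mtp`, and whether the T=1 kernel also accepts `output` is an assumption here:

```python
from collections import namedtuple

FakeTensor = namedtuple("FakeTensor", ["shape"])  # stand-in for torch.Tensor
calls = []  # records which stand-in kernel was invoked, and with what


def kernel_t1(**kw):
    calls.append(("t1", kw))
    return kw.get("output")


def kernel_mtp(**kw):
    calls.append(("mtp", kw))
    return kw.get("output")


def gdn_decode_bf16_state_wrapper(q, state, output, initial_state_indices=None, **kw):
    """Dispatch on T and forward the preallocated buffers (sketch)."""
    T = q.shape[1]
    if T == 1:
        return kernel_t1(q=q, initial_state_source=state, output=output, **kw)
    # T > 1: forward both the preallocated output and the state indices so
    # nothing is allocated inside the timed benchmark region.
    return kernel_mtp(
        q=q,
        initial_state_source=state,
        output=output,
        initial_state_indices=initial_state_indices,
        **kw,
    )
```

`bench_gdn_decode_bf16_state()` would then build `initial_state_indices` once, outside the timed loop, and pass it through for the T>1 case.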
2049-2068: ⚠️ Potential issue | 🟠 Major

Use `float32` `dt_bias` for BF16-state benchmark calls.

The `gdn_decode_bf16_state` kernel specification requires `dt_bias` as `[HV]` float32, as confirmed by kernel docstrings and test code comments. These benchmark paths currently create it with `dtype` (typically BF16), which diverges from the intended interface and may benchmark a different numerical path than intended.

Suggested fix

```diff
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")
```

Also applies to: 2256-2260
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 2049 - 2068, The benchmark is passing dt_bias with the wrong dtype for the BF16-state kernel; update the dt_bias tensor construction used in the gdn_decode_bf16_state benchmark calls so dt_bias is created as float32 (torch.float32) with shape [num_sab_heads * head_size] or [HV] as required, then pass that float32 dt_bias into gdn_decode_bf16_state_wrapper (and the second BF16-state call around the other occurrence). Locate the dt_bias variable used in the gdn_decode_bf16_state_wrapper invocations and change its dtype to torch.float32 while keeping device and shape identical.tests/gdn/test_decode_delta_rule.py (1)
824-831: ⚠️ Potential issue | 🟠 Major

Fix kernel dispatch for `seq_len > 1` in the pretranspose API test.

The test parametrizes `seq_len` with [1, 2, 3, 4] (line 824), but the direct kernel call at line 902 always invokes `gdn_decode_bf16_state`, which is the single-token (T=1) kernel. For `seq_len > 1`, the test should dispatch to `gdn_decode_bf16_state_mtp` and pass the `initial_state_indices` parameter. Currently this breaks verification of the multi-token path.

Suggested fix

```diff
-    # Direct improved kernel
-    out_direct = gdn_decode_bf16_state(
+    # Direct kernel: T=1 uses single-token path, T>1 uses MTP path
+    direct_kernel = (
+        gdn_decode_bf16_state if seq_len == 1 else gdn_decode_bf16_state_mtp
+    )
+    direct_kwargs = dict(
         A_log=A_log,
         a=a,
         dt_bias=dt_bias,
         softplus_beta=1.0,
         softplus_threshold=20.0,
         q=q,
         k=k,
         v=v,
         b=b_tensor,
         initial_state_source=state_direct,
         use_qk_l2norm_in_kernel=True,
         scale=scale,
-    )
+    )
+    if seq_len > 1:
+        direct_kwargs["initial_state_indices"] = torch.arange(
+            batch_size, dtype=torch.int32, device=device
+        )
+    out_direct = direct_kernel(**direct_kwargs)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 824 - 831, The test test_pretranspose_api_uses_gdn_decode_bf16_state incorrectly always calls the single-token kernel gdn_decode_bf16_state; update the dispatch so when seq_len > 1 it calls gdn_decode_bf16_state_mtp and supplies the initial_state_indices argument (preserving the existing call for seq_len == 1). Locate the direct kernel invocation around the test body (the call to gdn_decode_bf16_state) and add a conditional: if seq_len == 1 keep the existing call, else call gdn_decode_bf16_state_mtp with the same parameters plus initial_state_indices so the multi-token verification path is exercised.
🧹 Nitpick comments (2)

tests/gdn/test_decode_delta_rule.py (2)

1148-1179: Validate cached intermediate states in the BF16 MTP test path.

When `cache_intermediate_states=True`, the test currently doesn't verify buffer contents. Adding a reference comparison would protect the new intermediate-caching path from silent regressions.

Also applies to: 1218-1225
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 1148 - 1179, Add an explicit verification of the cached intermediate states when cache_intermediate_states=True: after calling gdn_decode_bf16_state_mtp with intermediate_states_buffer provided, run a reference call that produces the expected intermediate states (e.g., call gdn_decode_bf16_state_mtp or the float32 equivalent with caching disabled/producing a reference buffer) and compare intermediate_states_buffer to that reference using an appropriate numeric tolerance and dtype/device conversion (use torch.testing.assert_allclose or equivalent) so the BF16 MTP test path actually validates buffer contents; apply the same check in the other test location that mirrors this logic (the block around the second call referenced in the comment).
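The shape of such a check, shown on a toy decay recurrence rather than the real delta rule (the vectorized side stands in for the kernel-filled `intermediate_states_buffer`; shapes and tolerances are illustrative):

```python
import torch

torch.manual_seed(0)
B, T, HV, K, V = 2, 3, 2, 4, 4

state0 = torch.randn(B, HV, K, V)
decay = torch.rand(B, T, HV, 1, 1)

# "Kernel" side: per-step states computed vectorized and cached in BF16,
# standing in for the buffer gdn_decode_bf16_state_mtp would fill
# (toy recurrence, not the real delta rule).
cum = torch.cumprod(decay, dim=1)  # [B, T, HV, 1, 1]
intermediate_states_buffer = (state0.unsqueeze(1) * cum).to(torch.bfloat16)

# Reference side: step through tokens, caching each state, as the test's
# reference loop would.
ref_states, s = [], state0.clone()
for t in range(T):
    s = s * decay[:, t]
    ref_states.append(s.clone())
ref_buffer = torch.stack(ref_states, dim=1)

# The assertion the nitpick asks for: cached buffer vs reference, with a
# BF16-appropriate tolerance.
torch.testing.assert_close(
    intermediate_states_buffer.float(), ref_buffer, atol=2e-2, rtol=1e-2
)
```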
634-637: Remove or use the unused `alpha` helper parameter.

`alpha` is currently unused in the BF16 helper tests, which makes the test surface a bit misleading.

Also applies to: 937-940
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 634 - 637, The helper function in tests/gdn/test_decode_delta_rule.py declares an unused parameter alpha alongside beta and seed; either remove alpha from the helper signature and all its call sites in that file (including the BF16 helper usages) or modify the BF16 helper tests to actually use the alpha parameter, ensuring you update the function signature and every invocation consistently; reference the parameter names alpha, beta, seed when making the change so you catch all occurrences (the same unused-alpha issue also appears later in the file around the BF16 helper usages).
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/bench_gdn_decode.py
- flashinfer/gdn_decode.py
- flashinfer/gdn_kernels/__init__.py
- flashinfer/gdn_kernels/gdn_decode_bf16_state.py
- results_bf16_optimizations/cooprow_bf16_vs_optimized_fp32_mtp.md
- tests/gdn/test_decode_delta_rule.py
Add a high-performance CuTe DSL kernel for GDN decode with BF16 hidden state storage. Provides both T=1 (single token) and MTP (multi-token prediction) variants using a cooperative row approach.

Key design:
- Each warp processes one V-row at a time (4 warps = 4 V-rows/iter)
- cp.async pipeline with TILE_V=8 x TILE_K=128 tiles
- H state stored as BF16 in memory, FP32 in registers for compute
- ILP-optimized variant for large batch sizes (BS>=32)

Consolidated from separate cooprow file into canonical gdn_decode_bf16_state.py, replacing the old 32x128 H-chunk kernel. Updated gdn_decode.py dispatch to use the BF16 state kernel for both T=1 and MTP (T>1) when state is BF16 and K=V=128.

Benchmark results (B200, Qwen3-Next config, BF16 state MTP vs FP32 MTP):
- BS=1-2: 1.09-1.35x speedup
- BS=4-16: 1.24-2.21x speedup (biggest gains)
- BS=32-512: 1.62-1.81x steady-state speedup
- Peak: 13.8 TFLOPS (BS=512, T=8) vs FP32's 7.9 TFLOPS

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
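The "BF16 in memory, FP32 in registers" point from the commit message can be illustrated with a small PyTorch sketch of one generic gated outer-product update (schematic only: this is not the kernel's actual recurrence, tiling, or warp layout):

```python
import torch

torch.manual_seed(0)
K = V = 128

# Persistent hidden state as it sits in memory: BF16.
h_bf16 = torch.randn(K, V, dtype=torch.bfloat16)
k = torch.randn(K)
v = torch.randn(V)
alpha, beta = 0.95, 0.5  # illustrative gate/step values

# Load: widen to FP32 "registers" for the arithmetic.
h = h_bf16.float()

# One generic gated outer-product update, all in FP32
# (schematic stand-in for the kernel's per-token state update).
h = alpha * h + beta * torch.outer(k, v - (alpha * h).T @ k)

# Store: narrow back to BF16 for the persistent state.
h_bf16 = h.to(torch.bfloat16)
```

Keeping the accumulation in FP32 avoids compounding BF16 rounding error across tokens while halving the memory traffic for the [K, V] state.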
Force-pushed: d470d1d to a906f21
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/bench_gdn_decode.py (1)
1928-1931: ⚠️ Potential issue | 🟠 Major

Use float32 `dt_bias` in BF16-state benchmark paths.

BF16-state tests in this PR consistently feed `dt_bias` as float32, but these benchmark paths generate `dt_bias` with `dtype` (typically bf16/fp16). That can trigger dtype assertions/casts and distort or fail BF16-state benchmarking.

Suggested patch

```diff
-    dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
+    dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")
```

Also applies to: 2257-2260
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 1928 - 1931, The benchmark creates dt_bias with the generic dtype (bf16/fp16) causing dtype mismatches in BF16-state paths; change the dt_bias creation to explicitly use torch.float32 (e.g., dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda")) wherever dt_bias is constructed alongside A_log, a, b (the dt_bias variable in the block with A_log, a, b) and apply the same fix to the other identical occurrence later in the file.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_decode.py`:
- Line 2420: The function gated_delta_rule_mtp now sets the parameter
disable_state_update default to False which silently changes behavior; update it
to preserve prior behavior by restoring disable_state_update: bool = True in the
gated_delta_rule_mtp signature (or if the change is intentional, update the
function's docstring and any callers to document and adopt
disable_state_update=False) and ensure the docstring text for
gated_delta_rule_mtp describing the default matches the new default to avoid
mismatch.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- benchmarks/bench_gdn_decode.py
- flashinfer/gdn_decode.py
- flashinfer/gdn_kernels/__init__.py
- flashinfer/gdn_kernels/gdn_decode_bf16_state.py
- tests/gdn/test_decode_delta_rule.py
```diff
 def _test_gdn_decode_bf16_state_kernel(
     dtype: str,
     batch_size: int,
     num_q_heads: int,
     num_k_heads: int,
     num_v_heads: int,
     head_size: int,
-    seq_len: int,  # T=1,2,3,4
+    seq_len: int,
     scale: float,
     alpha: bool,
     beta: bool,
     seed: int | None = None,
 ):
```
Resolve unused alpha parameters in BF16-state helper tests.
alpha is not read in these helper bodies, and Ruff is already flagging it (ARG001). Rename to _alpha (or remove) to avoid lint noise/failures.
Suggested patch

```diff
 def _test_gdn_decode_bf16_state_kernel(
 @@
-    alpha: bool,
+    _alpha: bool,
 @@
 def _test_gdn_decode_bf16_state_t1_kernel(
 @@
-    alpha: bool,
+    _alpha: bool,
```

Also applies to: 928-939
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 632-632: Unused function argument: alpha
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gdn/test_decode_delta_rule.py` around lines 623 - 635, The helper
function _test_gdn_decode_bf16_state_kernel declares an unused parameter alpha
which triggers ARG001; rename alpha to _alpha (or remove it) to mark it as
intentionally unused and silence the linter, and apply the same rename to the
other BF16-state helper(s) in this file that also declare an unused alpha
parameter so all occurrences are fixed.
Actionable comments posted: 2
♻️ Duplicate comments (2)
benchmarks/bench_gdn_decode.py (1)
1878-1891: ⚠️ Potential issue | 🟠 Major

Benchmark the MTP kernel with the preallocated `output` tensor.

This timed branch still omits `output=output`, so each iteration benchmarks extra allocation/copy work instead of just the kernel path.

Suggested patch

```diff
     else:
         return gdn_decode_bf16_state_mtp(
             A_log=A_log,
             a=a,
             dt_bias=dt_bias,
             softplus_beta=softplus_beta,
             softplus_threshold=softplus_threshold,
             q=q,
             k=k,
             v=v,
             b=b,
             initial_state_source=state,
             use_qk_l2norm_in_kernel=use_qk_l2norm,
             scale=scale,
+            output=output,
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 1878 - 1891, The benchmark call to gdn_decode_bf16_state_mtp is missing the preallocated output tensor, causing allocations each iteration; update the call site where gdn_decode_bf16_state_mtp(...) is invoked (the function call shown) to pass output=output so the preallocated tensor is used, keeping all other named params (A_log, a, dt_bias, softplus_beta, softplus_threshold, q, k, v, b, initial_state_source=state, use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale) unchanged.

flashinfer/gdn_decode.py (1)
509-546: ⚠️ Potential issue | 🟠 Major

Preserve the previous read-only default for `disable_state_update`.

Changing the default to `False` means existing callers that omit this argument now mutate `initial_state`. That is a silent public-API behavior change.

Suggested patch

```diff
-    disable_state_update: bool = False,
+    disable_state_update: bool = True,
 @@
     disable_state_update (bool):
-        If True, the initial state is not updated. Default: ``False``.
+        If True, the initial state is not updated. Default: ``True``.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 509 - 546, The default for the disable_state_update parameter was changed and must be restored to preserve read-only behavior; in the function gdn_decode (the Gated Delta Rule MTP Kernel that accepts disable_state_update: bool), revert the default value of disable_state_update back to True, update any related docs/parameter description in the function docstring to match (i.e., indicate default True and that omitting the argument prevents mutation of initial_state), and ensure any downstream callers/tests relying on the previous default continue to behave the same.
🧹 Nitpick comments (1)
tests/gdn/test_decode_delta_rule.py (1)
1625-1702: Actually assert the cached BF16 intermediate states.

When `cache_intermediate_states=True`, this helper never checks `intermediate_states_buffer`. A broken beyond-T>4 cache-write path would still pass, which leaves the new feature effectively unverified.

Suggested test extension

```diff
     # Reference: step through tokens with bf16 state
     ref_state = input_state_ref_bf16.clone()
     ref_outputs = []
+    ref_intermediate_states = []
     for t in range(seq_len):
         ref_o_t, ref_state = decode_delta_rule(
             q[:, t].float(),
             k[:, t].float(),
 @@
             use_l2_norm=True,
             state_dtype=torch.bfloat16,
         )
         ref_outputs.append(ref_o_t)
+        if cache_intermediate_states:
+            ref_intermediate_states.append(
+                ref_state.transpose(-2, -1).contiguous().clone()
+            )
     ref_o = torch.stack(ref_outputs, dim=1).to(dtype_torch)
 @@
     torch.testing.assert_close(
         our_o.float(),
         ref_o.float(),
         atol=atol_o,
         rtol=rtol_o,
         msg=f"Output mismatch for MTP BF16 state kernel (B={batch_size}, T={seq_len})",
     )
+
+    if cache_intermediate_states and intermediate_states_buffer is not None:
+        ref_intermediate = torch.stack(ref_intermediate_states, dim=1)
+        torch.testing.assert_close(
+            intermediate_states_buffer.float(),
+            ref_intermediate.float(),
+            atol=0.02,
+            rtol=0.01,
+            msg=(
+                f"Intermediate-state cache mismatch for MTP BF16 state kernel "
+                f"(B={batch_size}, T={seq_len})"
+            ),
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 1625 - 1702, The test never asserts intermediate_states_buffer when cache_intermediate_states=True, so add assertions that the buffer returned/modified by gdn_decode_bf16_state_mtp (intermediate_states_buffer) matches the per-step BF16 intermediate states computed during the reference loop using decode_delta_rule (collect the per-token intermediate state values while building ref_outputs), comparing shapes/dtypes and using tight atol/rtol similar to output checks; reference intermediate_states_buffer, gdn_decode_bf16_state_mtp, decode_delta_rule, our_state and input_state_kernel to locate where to capture and compare the cached states and ensure the cache-write path is actually verified.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 206-212: The BF16 fast-path selection (use_bf16_state) must be
guarded against padded/negative pool indices: extend the predicate that sets
use_bf16_state (which currently uses _GDN_DECODE_BF16_STATE_AVAILABLE,
state_dtype, K and V) to also verify that initial_state_indices either is None
or contains no negative entries (e.g. check initial_state_indices is not
provided OR torch.all(initial_state_indices >= 0) is true) before enabling the
BF16 backend, because the BF16 pooled-state backend does not implement
negative-index semantics and will mis-handle padding slots.
- Around line 235-250: The BF16 MTP fast path currently ignores the caller's
preallocated output and allocates a new tensor then copies back; modify the call
to _gated_delta_rule_bf16_state_mtp so it writes directly into the caller's
output buffer (pass the caller's output as an explicit output argument and
ensure shape/dtype/stride compatibility), remove the downstream copy-from-kernel
back into `output`, and preserve existing flags (use_pool,
initial_state/initial_state_indices, use_qk_l2norm, scale_val) when forwarding
to _gated_delta_rule_bf16_state_mtp so the kernel writes in-place into the
provided `output` buffer.
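Stripped of the tensor machinery, the guard suggested in the first prompt above has this shape (a plain-Python stand-in; in the real dispatch the dtype comparison would be against torch.bfloat16 and the index check a tensor op such as `bool((initial_state_indices >= 0).all())`):

```python
def select_bf16_state_backend(available, state_dtype, K, V, initial_state_indices):
    """Decide whether the BF16 pooled-state fast path may be taken (sketch).

    The extra clause rejects padded/negative pool indices, whose semantics
    the BF16 backend does not implement.
    """
    indices_ok = initial_state_indices is None or all(
        i >= 0 for i in initial_state_indices
    )
    return (
        available
        and state_dtype == "bfloat16"
        and K == 128
        and V == 128
        and indices_ok
    )
```

Any batch containing a padding slot then falls back to the general backend instead of silently mis-handling the negative index.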
---
Duplicate comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 1878-1891: The benchmark call to gdn_decode_bf16_state_mtp is
missing the preallocated output tensor, causing allocations each iteration;
update the call site where gdn_decode_bf16_state_mtp(...) is invoked (the
function call shown) to pass output=output so the preallocated tensor is used,
keeping all other named params (A_log, a, dt_bias, softplus_beta,
softplus_threshold, q, k, v, b, initial_state_source=state,
use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale) unchanged.
In `@flashinfer/gdn_decode.py`:
- Around line 509-546: The default for the disable_state_update parameter was
changed and must be restored to preserve read-only behavior; in the function
gdn_decode (the Gated Delta Rule MTP Kernel that accepts disable_state_update:
bool), revert the default value of disable_state_update back to True, update any
related docs/parameter description in the function docstring to match (i.e.,
indicate default True and that omitting the argument prevents mutation of
initial_state), and ensure any downstream callers/tests relying on the previous
default continue to behave the same.
---
Nitpick comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 1625-1702: The test never asserts intermediate_states_buffer when
cache_intermediate_states=True, so add assertions that the buffer
returned/modified by gdn_decode_bf16_state_mtp (intermediate_states_buffer)
matches the per-step BF16 intermediate states computed during the reference loop
using decode_delta_rule (collect the per-token intermediate state values while
building ref_outputs), comparing shapes/dtypes and using tight atol/rtol similar
to output checks; reference intermediate_states_buffer,
gdn_decode_bf16_state_mtp, decode_delta_rule, our_state and input_state_kernel
to locate where to capture and compare the cached states and ensure the
cache-write path is actually verified.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cb21f102-081e-40a1-854c-620a96b8d9d3
📒 Files selected for processing (4)
- benchmarks/bench_gdn_decode.py
- flashinfer/gdn_decode.py
- flashinfer/gdn_kernels/__init__.py
- tests/gdn/test_decode_delta_rule.py
```diff
 use_bf16_state = (
     _GDN_DECODE_BF16_STATE_AVAILABLE
     and state_dtype == torch.bfloat16
     and T in (1, 2, 3, 4)
     and K == 128
     and V == 128
 )
-if use_gdn_decode_klast_bf16_state:
+if use_bf16_state:
```
Don’t route negative pool indices into the BF16 fast path.
This predicate still selects the BF16 backend for pooled bf16 state even when initial_state_indices contains padding slots (-1). That backend does not implement the negative-index semantics, so these calls can read/write the wrong pool row instead of honoring padding.
Suggested guard

```diff
 use_bf16_state = (
     _GDN_DECODE_BF16_STATE_AVAILABLE
     and state_dtype == torch.bfloat16
     and K == 128
     and V == 128
 )
+if use_bf16_state and use_pool and (initial_state_indices < 0).any().item():
+    raise ValueError(
+        "Negative initial_state_indices are only supported with float32 state; "
+        "the BF16 fast path does not support padding slots."
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_decode.py` around lines 206 - 212, The BF16 fast-path
selection (use_bf16_state) must be guarded against padded/negative pool indices:
extend the predicate that sets use_bf16_state (which currently uses
_GDN_DECODE_BF16_STATE_AVAILABLE, state_dtype, K and V) to also verify that
initial_state_indices either is None or contains no negative entries (e.g. check
initial_state_indices is not provided OR torch.all(initial_state_indices >= 0)
is true) before enabling the BF16 backend, because the BF16 pooled-state backend
does not implement negative-index semantics and will mis-handle padding slots.
```python
# MTP kernel supports T>=1 and pool+indices
out = _gated_delta_rule_bf16_state_mtp(
    A_log=A_log,
    a=a,
    dt_bias=dt_bias,
    softplus_beta=1.0,
    softplus_threshold=20.0,
    q=q,
    k=k,
    v=v,
    b=b,
    initial_state_source=initial_state if use_pool else state,
    initial_state_indices=initial_state_indices,
    use_qk_l2norm_in_kernel=use_qk_l2norm,
    scale=scale_val,
)
```
Pass the caller’s output buffer into the BF16 MTP backend.
The new T>1 / pool fast path always allocates a fresh output tensor and then copies it back into output. That defeats preallocation and adds avoidable device traffic in the hot decode loop.
Suggested patch

```diff
 else:
     # MTP kernel supports T>=1 and pool+indices
     out = _gated_delta_rule_bf16_state_mtp(
         A_log=A_log,
         a=a,
         dt_bias=dt_bias,
         softplus_beta=1.0,
         softplus_threshold=20.0,
         q=q,
         k=k,
         v=v,
         b=b,
         initial_state_source=initial_state if use_pool else state,
         initial_state_indices=initial_state_indices,
         use_qk_l2norm_in_kernel=use_qk_l2norm,
         scale=scale_val,
+        output=output,
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_decode.py` around lines 235 - 250, The BF16 MTP fast path
currently ignores the caller's preallocated output and allocates a new tensor
then copies back; modify the call to _gated_delta_rule_bf16_state_mtp so it
writes directly into the caller's output buffer (pass the caller's output as an
explicit output argument and ensure shape/dtype/stride compatibility), remove
the downstream copy-from-kernel back into `output`, and preserve existing flags
(use_pool, initial_state/initial_state_indices, use_qk_l2norm, scale_val) when
forwarding to _gated_delta_rule_bf16_state_mtp so the kernel writes in-place
into the provided `output` buffer.
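The buffer-forwarding pattern the review asks for can be shown with a toy stdlib kernel (names and shapes are illustrative; the real kernel writes into preallocated CUDA tensors):

```python
def mtp_kernel(values, output=None):
    """Hypothetical stand-in for a decode kernel that honors a caller-provided
    output buffer instead of allocating a fresh one on every call."""
    if output is None:
        output = [0.0] * len(values)  # slow path: allocate
    for i, v in enumerate(values):
        output[i] = v * 2.0  # write results in place
    return output

buf = [0.0] * 4
out = mtp_kernel([1.0, 2.0, 3.0, 4.0], output=buf)
assert out is buf  # no allocation or device copy on the hot path
```

The caller-side guard `if output is not None and out is not output: output.copy_(out)` from the suggested patch then becomes a no-op whenever the kernel honored the buffer.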
Force-pushed 2ba157b to 896ebb3.
Force-pushed 896ebb3 to 46306ff.
Resolve 4 conflicting files after main's major refactor of gdn_decode.py from a 2643-line monolith into a 645-line API layer with kernel code extracted into flashinfer/gdn_kernels/ submodules. Conflict resolutions:
- gdn_decode.py: Accept main's refactored API layer, port feature branch's BF16 state import renames, updated dispatch logic (T=1 no-pool → bf16_state, else → bf16_state_mtp), and disable_state_update default docstring fix.
- gdn_kernels/__init__.py: Merge both: keep main's expanded exports and add feature branch's MTP exports and backward-compat aliases.
- gdn_kernels/gdn_decode_bf16_state.py: Accept feature branch's complete rewrite with coop-row kernel approach + MTP kernel.
- tests/gdn/test_decode_delta_rule.py: Start from main's version (pool+indices tests, negative indices tests), apply feature branch's renames, dispatch split, new T=1 and MTP test functions, remove CI skip marker.

AI-assisted merge resolution.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Force-pushed 46306ff to 421334a.
/bot run

@ameynaik-hub is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/bench_gdn_decode.py (1)
2281-2306: ⚠️ Potential issue | 🟠 Major

Don't charge cache traffic in the no-cache BF16 benchmark.

`gdn_decode_bytes()` adds `[B, T, HV, V, K]` intermediate-state writes for every `seq_len > 1`, but this benchmark never allocates or forwards `intermediate_states_buffer`. The reported TB/s is therefore inflated for every BF16 MTP run.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 2281 - 2306, The benchmark is incorrectly charging intermediate-state write traffic for BF16 no-cache runs; update the gdn_decode_bytes call in the benchmark so it does not count those writes (e.g., set disable_state_update=True or set state_dtype_bytes=0) when the intermediate_states_buffer is not allocated/forwarded; change the call that currently uses disable_state_update=False and state_dtype_bytes=2 to instead reflect no state updates for gdn_decode_bf16_state (e.g., disable_state_update=True or state_dtype_bytes=0) so TB/s is not inflated.
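The gating the comment asks for can be modeled in a few lines; this is a rough, stdlib-only traffic model written for illustration, not FlashInfer's actual `gdn_decode_bytes()` formula:

```python
def gdn_decode_bytes_est(B, T, HV, K, V, elem_bytes=2, cache_intermediate_states=False):
    """Rough bandwidth model (assumed terms): read q/k/v activations, read and
    write the recurrent state once, and charge per-token intermediate-state
    writes only when the cache buffer actually exists."""
    qkv = 3 * B * T * HV * K * elem_bytes        # activation reads
    state = 2 * B * HV * K * V * elem_bytes      # state read + write
    cache = 0
    if cache_intermediate_states and T > 1:
        cache = B * T * HV * K * V * elem_bytes  # intermediate-state writes
    return qkv + state + cache
```

Keying the `cache` term on the same condition that allocates `intermediate_states_buffer` keeps reported TB/s honest for the no-cache runs.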
♻️ Duplicate comments (4)
flashinfer/gdn_decode.py (3)
509-510: ⚠️ Potential issue | 🟠 Major

Don't silently flip `gated_delta_rule_mtp()` to mutating by default.

`disable_state_update=False` still changes the public default from verify-mode semantics to in-place state mutation for callers that omit the flag. The docstring now matches, but the API break remains.

Suggested patch

```diff
-    disable_state_update: bool = False,
+    disable_state_update: bool = True,
@@
-        If True, the initial state is not updated. Default: ``False``.
+        If True, the initial state is not updated. Default: ``True``.
```

Also applies to: 545-546
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 509 - 510, The function signature(s) for gated_delta_rule_mtp (and the other occurrence at lines ~545-546) changed the default to disable_state_update=False which silently makes the API mutating by default; revert the public API to preserve verify-mode by setting disable_state_update=True in the function signature(s) (or otherwise ensure the default remains non-mutating), and keep the docstring consistent with that non-mutating default so callers that omit the flag keep previous semantics.
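One compatibility-preserving way to stage such a default change is a sentinel plus `FutureWarning`; the sketch below is a hypothetical wrapper illustrating the pattern, not the real `gated_delta_rule_mtp` signature:

```python
import warnings

_UNSET = object()  # sentinel: distinguishes "omitted" from an explicit False

def gated_delta_rule_mtp_sketch(state, tokens, disable_state_update=_UNSET):
    """Hypothetical API showing how to change a default without silently
    flipping behavior for callers that omit the flag."""
    if disable_state_update is _UNSET:
        warnings.warn(
            "disable_state_update will default to False in a future release; "
            "pass it explicitly.", FutureWarning, stacklevel=2)
        disable_state_update = True  # preserve the old read-only default
    out = [state + t for t in tokens]
    if not disable_state_update:
        state = out[-1]  # mutate-state path, only when opted in
    return out, state
```

Callers who never passed the flag keep verify-mode semantics and get a deprecation signal instead of a silent behavior change.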
236-250: ⚠️ Potential issue | 🟠 Major

Forward the caller's `output` buffer into the BF16 MTP backend.

This path still lets `_gated_delta_rule_bf16_state_mtp()` allocate a fresh output and then copies it back into `output`, so preallocation does not actually remove the allocation/device copy on the hot path.

Suggested patch

```diff
 out = _gated_delta_rule_bf16_state_mtp(
     A_log=A_log,
     a=a,
     dt_bias=dt_bias,
     softplus_beta=1.0,
     softplus_threshold=20.0,
     q=q,
     k=k,
     v=v,
     b=b,
     initial_state_source=initial_state if use_pool else state,
     initial_state_indices=initial_state_indices,
     use_qk_l2norm_in_kernel=use_qk_l2norm,
     scale=scale_val,
+    output=output,
 )
 output_provided = output is not None
 target_dtype = output.dtype if output_provided else q.dtype
-if output is not None:
+if output is not None and out is not output:
     output.copy_(out)
 else:
     output = out
```

Also applies to: 251-255
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 236 - 250, The call site currently lets _gated_delta_rule_bf16_state_mtp allocate and return a new tensor which is later copied into output; instead, forward the caller's preallocated output buffer into the BF16 MTP backend by passing the existing output tensor as an explicit argument (e.g., add an output=output or out=output parameter) when calling _gated_delta_rule_bf16_state_mtp (and update the second duplicate call at lines 251-255 similarly); ensure the backend function signature ( _gated_delta_rule_bf16_state_mtp ) accepts this output/out parameter and writes into it in-place so no new allocation or device copy occurs on the hot path.
206-212: ⚠️ Potential issue | 🟠 Major

Guard the BF16 fast path against padded pool indices.

`initial_state_indices < 0` still enables the BF16 backend here, but the BF16 pooled path does not implement padding-slot semantics. A bf16 pooled call with `-1` can therefore read or update the wrong pool row instead of producing padded output.

Suggested guard

```diff
 use_bf16_state = (
     _GDN_DECODE_BF16_STATE_AVAILABLE
     and state_dtype == torch.bfloat16
     and K == 128
     and V == 128
 )
+if use_bf16_state and use_pool and (initial_state_indices < 0).any().item():
+    raise ValueError(
+        "Negative initial_state_indices are only supported with float32 state; "
+        "the BF16 fast path does not support padding slots."
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 206 - 212: The BF16 fast path (computed into `use_bf16_state`) must be disabled when any pooled indices are padding; update the condition that sets `use_bf16_state` (which currently checks `_GDN_DECODE_BF16_STATE_AVAILABLE`, `state_dtype`, `K == 128`, `V == 128`) to also verify that `initial_state_indices` contains no negative values (e.g., ensure `initial_state_indices.min() >= 0` or `torch.all(initial_state_indices >= 0)`) before enabling the BF16 path so the bf16 pooled implementation is never used when padding slots (negative indices) are present.

benchmarks/bench_gdn_decode.py (1)
1878-1891: ⚠️ Potential issue | 🟠 Major

Use the preallocated `output` buffer in the BF16-state MTP benchmark path.

The wrapper still ignores its `output` argument for `T > 1`, so every timed iteration includes an avoidable allocation/copy and the BF16-state numbers are not directly comparable to the other kernels.

Suggested patch

```diff
 return gdn_decode_bf16_state_mtp(
     A_log=A_log,
     a=a,
     dt_bias=dt_bias,
     softplus_beta=softplus_beta,
     softplus_threshold=softplus_threshold,
     q=q,
     k=k,
     v=v,
     b=b,
     initial_state_source=state,
     use_qk_l2norm_in_kernel=use_qk_l2norm,
     scale=scale,
+    output=output,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 1878 - 1891, The wrapper for BF16-state MTP (gdn_decode_bf16_state_mtp) currently ignores the provided output buffer for T>1 and allocates/copies inside each iteration; modify the wrapper so it writes into the passed-in output buffer instead of creating a new one for multi-step runs: detect the T>1 path in the wrapper that calls gdn_decode_bf16_state_mtp (the call shown with A_log,a,dt_bias,...,initial_state_source=state,use_qk_l2norm_in_kernel,scale) and forward the provided output argument into the underlying implementation or reuse it for accumulating results, ensuring no per-iteration allocation/copy happens and the shape/dtype semantics match the existing single-step case.
🧹 Nitpick comments (1)
tests/gdn/test_decode_delta_rule.py (1)
1625-1653: Actually assert the BF16 MTP intermediate-state cache.

When `cache_intermediate_states=True`, this helper allocates and passes `intermediate_states_buffer`, but never compares its contents with the step-by-step reference. A regression in the new caching path would still pass as long as the outputs stay correct.

Also applies to: 1660-1680
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` around lines 1625 - 1653, The test allocates intermediate_states_buffer when cache_intermediate_states is True but never asserts its contents; update the test after calling gdn_decode_bf16_state_mtp (the call that sets intermediate_states_buffer) to compare intermediate_states_buffer against the step-by-step reference buffer used for the MTP/disable_state_update verification (the same reference used for outputs), e.g., fetch the expected intermediate states produced by the stepwise/reference decoder and assert equality (or close with appropriate dtype tolerance) against intermediate_states_buffer; ensure you handle None when cache_intermediate_states is False and reuse variables like intermediate_states_buffer, our_state, and the reference stepwise buffer to locate the correct data to compare.
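The missing assertion pattern can be illustrated with a toy decode loop (pure Python stand-in; in the real test the comparison would use tensors and `torch.testing.assert_close`):

```python
def mtp_decode_with_cache(state, tokens):
    """Toy multi-token decode that also returns per-step intermediate states,
    standing in for the kernel's intermediate_states_buffer."""
    outputs, cache = [], []
    for t in tokens:
        state = state + t
        outputs.append(state * 2)
        cache.append(state)  # what the real kernel would write per token
    return outputs, cache

# Reference loop: run one token at a time and collect the states the cache
# buffer is expected to contain.
state, ref_outputs, ref_states = 0, [], []
for t in [1, 2, 3]:
    state += t
    ref_outputs.append(state * 2)
    ref_states.append(state)

outs, buf = mtp_decode_with_cache(0, [1, 2, 3])
assert outs == ref_outputs  # output check (already present in the test)
assert buf == ref_states    # the missing cache-buffer assertion
```

Without the second assertion, a kernel that writes garbage into the cache but still produces correct outputs would pass.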
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f4c83a1b-0f97-4567-9cdd-ef2a909543be
📒 Files selected for processing (5)
- benchmarks/bench_gdn_decode.py
- flashinfer/gdn_decode.py
- flashinfer/gdn_kernels/__init__.py
- flashinfer/gdn_kernels/gdn_decode_bf16_state.py
- tests/gdn/test_decode_delta_rule.py
```diff
 elif args.version == "bf16_state":
     # BF16 state benchmark: T=1 and MTP T>=2 vs FP32 MTP
     run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm)
 else:
-    # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state)
+    # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_bf16_state)
     run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
```
--compare is ignored for decode benchmarks.
After adding the bf16_state branch, the fallback path always calls run_all_layouts_benchmark(). --compare --version pretranspose or --compare --version nontranspose therefore never reaches run_comparison_benchmark(), despite the help text and examples promising that behavior.
Suggested patch

```diff
-if args.version == "mtp":
-    # MTP mode: use comparison or flashinfer-only
-    if args.compare:
-        run_comparison_benchmark(args, dtype, use_qk_l2norm)
-    else:
-        run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
-elif args.version == "bf16_state":
+if args.version == "bf16_state":
     # BF16 state benchmark: T=1 and MTP T>=2 vs FP32 MTP
     run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm)
+elif args.compare:
+    run_comparison_benchmark(args, dtype, use_qk_l2norm)
+elif args.version == "mtp":
+    run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
 else:
     # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_bf16_state)
     run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/bench_gdn_decode.py` around lines 2802 - 2807, The fallback branch
currently always calls run_all_layouts_benchmark(), ignoring args.compare;
adjust the else logic so that when args.compare is true and args.version is
"pretranspose" or "nontranspose" you call run_comparison_benchmark(args, dtype,
use_qk_l2norm) instead of run_all_layouts_benchmark(), otherwise keep calling
run_all_layouts_benchmark(); keep the existing bf16_state branch that calls
run_gdn_decode_bf16_state_benchmark intact.
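The corrected routing can be checked in isolation with `argparse`; the runner names below are the ones used in this thread, returned as strings purely for illustration:

```python
import argparse

def dispatch(args):
    """Sketch of the fixed CLI routing: --compare wins for the
    single-layout versions instead of falling through."""
    if args.version == "bf16_state":
        return "bf16_state_benchmark"
    if args.compare:
        return "comparison_benchmark"   # pretranspose/nontranspose/mtp + --compare
    if args.version == "mtp":
        return "flashinfer_only_benchmark"
    return "all_layouts_benchmark"

p = argparse.ArgumentParser()
p.add_argument("--version", default="pretranspose")
p.add_argument("--compare", action="store_true")

assert dispatch(p.parse_args(["--version", "pretranspose", "--compare"])) == "comparison_benchmark"
assert dispatch(p.parse_args(["--version", "bf16_state"])) == "bf16_state_benchmark"
assert dispatch(p.parse_args(["--version", "mtp"])) == "flashinfer_only_benchmark"
```

Ordering matters: checking `bf16_state` first, then `--compare`, matches the help-text promise for all versions.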
Forward intermediate_states_buffer, disable_state_update, and initial_state_indices to the MTP kernel. Pre-allocate buffers outside the timed lambda to avoid per-call overhead in CUPTI measurements.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Actionable comments posted: 2
♻️ Duplicate comments (1)
benchmarks/bench_gdn_decode.py (1)
2839-2850: ⚠️ Potential issue | 🟠 Major

`--compare` is still ignored for decode benchmarks.

For `--version pretranspose` and `--version nontranspose`, this branch still falls through to `run_all_layouts_benchmark()`, so the advertised single-layout FlashInfer-vs-Triton mode never runs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_gdn_decode.py` around lines 2839 - 2850, The branch ignores --compare for single-layout decode modes because versions "pretranspose" and "nontranspose" fall through to run_all_layouts_benchmark(); change control flow to handle these versions explicitly: add an elif (or cases) for args.version == "pretranspose" or "nontranspose" that mirrors the "mtp" handling by calling run_comparison_benchmark(args, dtype, use_qk_l2norm) when args.compare is true, otherwise call run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm), leaving run_all_layouts_benchmark and run_gdn_decode_bf16_state_benchmark unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_gdn_decode.py`:
- Around line 2281-2292: The bandwidth calculation double-counts
intermediate-state traffic even when intermediate_states_buffer is not
allocated; update gdn_decode_bytes() usage (or the bandwidth accounting around
it) to only include the [B, T, HV, K, V] intermediate-state bytes when
cache_intermediate_states is true and intermediate_states_buffer is allocated
(i.e., same condition used to create intermediate_states_buffer:
cache_intermediate_states and T > 1). Locate references to
intermediate_states_buffer, cache_intermediate_states, and gdn_decode_bytes()
(also the duplicate accounting block around the second occurrence noted) and
wrap or gate the intermediate-state contribution so TB/s is only incremented
when caching is enabled.
- Around line 1836-1838: The wrapper currently passes k (documented as
[B,T,H,K]) unchanged into gated_delta_rule_mtp while deriving H from q, which
breaks when num_q_heads != num_k_heads for the bf16_state MTP path; update the
benchmark wrapper to detect when bf16_state is chosen and either (a)
expand/reshape k from num_k_heads to num_q_heads by repeating or broadcasting
the key-head dimension so k aligns with q (match q.shape[2]) before calling
gated_delta_rule_mtp(q, k, v, ...), or (b) raise a clear error rejecting
mismatched num_k_heads vs num_q_heads for bf16_state; reference
variables/functions: gated_delta_rule_mtp, q, k, v, bf16_state, num_k_heads,
num_q_heads.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 18cc0a05-1daa-4fed-aa45-f72b4113cf36
📒 Files selected for processing (1)
benchmarks/bench_gdn_decode.py
```python
q: torch.Tensor,  # [B, T, H_Q, K]
k: torch.Tensor,  # [B, T, H_K, K]
v: torch.Tensor,  # [B, T, HV, V]
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
set -euo pipefail
printf '\n--- low-level BF16-state MTP contract ---\n'
sed -n '2173,2223p' flashinfer/gdn_kernels/gdn_decode_bf16_state.py
printf '\n--- benchmark wrapper ---\n'
sed -n '1835,1898p' benchmarks/bench_gdn_decode.py
printf '\n--- benchmark tensor allocation ---\n'
sed -n '2259,2320p' benchmarks/bench_gdn_decode.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 6298
Expand k to query heads before calling the low-level BF16-state MTP kernel.
gated_delta_rule_mtp() documents its k parameter as [B, T, H, K] and derives H from q.shape, expecting keys to match the query head dimension. The benchmark wrapper forwards k directly without expanding it from num_k_heads to num_q_heads, causing GQA/MQA cases where num_q_heads != num_k_heads to either mis-benchmark or fail on the BF16-state path. Either materialize query-head-aligned keys in the wrapper or reject mismatched head counts for bf16_state.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/bench_gdn_decode.py` around lines 1836 - 1838, The wrapper
currently passes k (documented as [B,T,H,K]) unchanged into gated_delta_rule_mtp
while deriving H from q, which breaks when num_q_heads != num_k_heads for the
bf16_state MTP path; update the benchmark wrapper to detect when bf16_state is
chosen and either (a) expand/reshape k from num_k_heads to num_q_heads by
repeating or broadcasting the key-head dimension so k aligns with q (match
q.shape[2]) before calling gated_delta_rule_mtp(q, k, v, ...), or (b) raise a
clear error rejecting mismatched num_k_heads vs num_q_heads for bf16_state;
reference variables/functions: gated_delta_rule_mtp, q, k, v, bf16_state,
num_k_heads, num_q_heads.
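A stdlib sketch of option (a), expanding key heads to match query heads (nested lists stand in for a `[B, T, H_K, K]` tensor; with torch this would be `k.repeat_interleave(group, dim=2)`):

```python
def expand_kv_heads(k, num_q_heads):
    """Illustrative GQA fix: repeat each key head so k's head dimension
    matches q's. k is a nested list shaped [B][T][H_K][K]."""
    out = []
    for batch in k:
        bt = []
        for step in batch:
            h_k = len(step)
            assert num_q_heads % h_k == 0, "q heads must be a multiple of k heads"
            group = num_q_heads // h_k
            # Each key head serves `group` consecutive query heads.
            bt.append([head for head in step for _ in range(group)])
        out.append(bt)
    return out
```

Option (b), rejecting mismatched head counts with a clear error, is the `assert` above promoted to a `ValueError` whenever `num_q_heads != num_k_heads`.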
Resolve conflict in flashinfer/gdn_decode.py by accepting main's disable_state_update deprecation warning from PR flashinfer-ai#2730.

AI-assisted: Claude Code
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
/bot run

@ameynaik-hub is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
Resolve conflicts with ae9a64d (feat(gdn): add padding):
- kernel: take our rewritten T=1 kernels (old kernels deleted); MTP kernel already has `if cache_idx >= 0:` guard
- test: trivial blank-line conflict; new padding test auto-merged

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
/bot run

[FAILED] Pipeline #46822953: 6/20 passed
…ode kernel

Fix 5 CI test failures in test_decode_delta_rule.py after merging main's padding commit (ae9a64d) into ameyn/gdn_bf16_improvements:

1. MTP kernel padding fix: redirect cache_idx < 0 to slot 0 instead of skipping, matching the test expectation that padding slots write to a null buffer rather than producing uninitialized output.
2. Small-batch routing fix: route B < ILP_BATCH_THRESHOLD through the MTP kernel's T=1 path instead of the cooprow kernel, which has known correctness issues at small batch sizes (e.g. B=1, B=2). Remove the cooprow dispatch code entirely.

AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
/bot run
Add a high-performance CuTe DSL kernel for GDN decode with BF16 hidden state storage. Provides both T=1 (single token) and MTP (multi-token prediction) variants using a cooperative row approach.
Key design:
Consolidated from separate cooprow file into canonical gdn_decode_bf16_state.py, replacing the old 32x128 H-chunk kernel. Updated gdn_decode.py dispatch to use BF16 state kernel for both T=1 and MTP (T>1) when state is BF16 and K=V=128.
Benchmark results (B200, Qwen3-Next config, BF16 state MTP vs FP32 MTP):
Cooprow BF16 State vs Optimized FP32 MTP Benchmark
GPU: B200
Config: Qwen3-Next (q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON)
Mode: `cache_intermediate_states=ON`, `disable_state_update=True`

Kernels compared:
- `gated_delta_rule_bf16state_cooprow_mtp`: cooperative row BF16 state kernel
- `gated_delta_rule_mtp`: FP32 state kernel with ILP rows (1/2/4/8) + SMEM V caching

Result tables (data not reproduced here): 1. Cooprow BF16 State Kernel Time (us); 2. Optimized FP32 MTP Kernel Time (us); 3. Speedup (FP32 time / BF16 time, >1.0 = BF16 wins)
Summary
AI-assisted (Claude Code)
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed, and are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Refactor
Tests
Documentation