
Add blackwell GDN prefill impl#2742

Open
dianzhangchen wants to merge 8 commits into flashinfer-ai:main from dianzhangchen:gdn_dev

Conversation


@dianzhangchen dianzhangchen commented Mar 10, 2026

📌 Description

A chunked linear attention kernel for Gated Delta Networks (GDN) on NVIDIA Blackwell (SM100). It implements the chunk-wise gated delta rule using CuTe-DSL.

Key Features:

  • Supports Persistent & Non-persistent modes
  • Supports fixed-length and variable-length sequences (cu_seqlens)
  • Supports grouped value attention (GVA) where h_v is a multiple of h_q
  • Supports head dimension 128
  • Supports f16/bf16 input and output data types
  • Supports initial_state to provide an initial recurrent state
  • Supports the output_final_state flag to return the final state
  • State input and output are in f32 (fp16/bf16 not supported yet)
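For reference, the gated delta rule recurrence that the chunked kernel parallelizes can be sketched per head in plain NumPy. This is an illustrative FP32 sketch under conventional GDN definitions (per-step log decay g, write strength beta); it is not the kernel's actual API:

```python
import numpy as np

def gated_delta_rule_ref(q, k, v, g, beta, scale=None, initial_state=None):
    """Sequential FP32 reference for one head.

    q, k: [T, d_k]; v: [T, d_v]; g, beta: [T]; state S: [d_k, d_v].
    """
    T, d_k = q.shape
    d_v = v.shape[1]
    scale = d_k ** -0.5 if scale is None else scale
    S = (np.zeros((d_k, d_v), dtype=np.float32)
         if initial_state is None else initial_state.astype(np.float32))
    o = np.zeros((T, d_v), dtype=np.float32)
    for t in range(T):
        S = np.exp(g[t]) * S                   # gated decay of the running state
        err = v[t] - k[t] @ S                  # delta-rule prediction error
        S = S + beta[t] * np.outer(k[t], err)  # rank-1 correction of the state
        o[t] = (scale * q[t]) @ S              # read out with the scaled query
    return o, S
```

The kernel evaluates this recurrence chunk-wise so most of the work becomes matrix multiplies; the final S here corresponds to the f32 state returned when output_final_state=True.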

Performance on B200

  • dtype: bfloat16
  • use_initial_state
  • output_final_state=True
  • Fixed-length (fixlen) case
| batch_size | seq_len | qk_heads | v_heads | head_dim | dtype | gdn_ms | fla_ms | speedup |
|---:|---:|---:|---:|---:|---|---:|---:|---:|
| 1 | 512 | 96 | 96 | 128 | bfloat16 | 0.0501 | 0.2119 | 4.2275 |
| 1 | 1024 | 96 | 96 | 128 | bfloat16 | 0.0869 | 0.2233 | 2.5696 |
| 1 | 4096 | 96 | 96 | 128 | bfloat16 | 0.3082 | 0.7741 | 2.5118 |
| 1 | 8192 | 96 | 96 | 128 | bfloat16 | 0.6009 | 1.4919 | 2.4827 |
| 9 | 512 | 32 | 32 | 128 | bfloat16 | 0.0986 | 0.3030 | 3.0740 |
| 9 | 1024 | 32 | 32 | 128 | bfloat16 | 0.1735 | 0.5529 | 3.1858 |
| 9 | 4096 | 32 | 32 | 128 | bfloat16 | 0.6189 | 2.0558 | 3.3219 |
| 9 | 8192 | 32 | 32 | 128 | bfloat16 | 1.2123 | 4.0530 | 3.3434 |
| 33 | 512 | 32 | 32 | 128 | bfloat16 | 0.3461 | 1.0199 | 2.9469 |
| 33 | 1024 | 32 | 32 | 128 | bfloat16 | 0.6159 | 1.9385 | 3.1474 |
| 33 | 4096 | 32 | 32 | 128 | bfloat16 | 2.2590 | 7.4736 | 3.3084 |
| 33 | 8192 | 32 | 32 | 128 | bfloat16 | 4.4938 | 14.8813 | 3.3115 |
| 1 | 512 | 148 | 148 | 128 | bfloat16 | 0.0521 | 0.2208 | 4.2371 |
| 1 | 1024 | 148 | 148 | 128 | bfloat16 | 0.0955 | 0.3089 | 3.2342 |
| 1 | 4096 | 148 | 148 | 128 | bfloat16 | 0.3232 | 1.1166 | 3.4551 |
| 1 | 8192 | 148 | 148 | 128 | bfloat16 | 0.6310 | 2.1957 | 3.4799 |

Issues

  • The current version does not parallelize over the sequence-length dimension, so performance is limited when both batch size and head count are small: only a fraction of the 148 SMs is occupied, leaving the GPU underutilized.
  • Current version supports GVA only; GQA is not supported.
  • State input and output are FP32 only; FP16 is not supported in the current version.
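A hypothetical back-of-the-envelope model of the first limitation; the sm_utilization helper and the one-block-per-(batch, head) grid assumption are illustrative, not taken from the kernel code:

```python
def sm_utilization(batch_size, num_heads, num_sms=148):
    """Fraction of SMs occupied if the grid has one block per
    (batch, head) pair and no sequence-length parallelism."""
    blocks = batch_size * num_heads
    return min(blocks, num_sms) / num_sms

# batch=1 with 96 heads yields 96 blocks on 148 SMs (roughly 65% occupancy
# at best), while batch=9 with 32 heads yields 288 blocks and saturates them.
print(sm_utilization(1, 96), sm_utilization(9, 32))
```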

🔍 Related Issues

chunk_gated_delta_rule for Blackwell #2340

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Blackwell GPU support for GDN (Gated Delta Network) linear attention with SM-specific dispatch.
    • Built-in performance benchmark tool (fixed/variable lengths, sweep mode, formatted reports).
    • Exposes optional kernel entry for Blackwell prefill and public symbols for kernel access.
  • Tests

    • Comprehensive GPU test suite: fixed/variable lengths, grouped-value attention variants, final-state outputs, and determinism across runs.


coderabbitai bot commented Mar 10, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

📝 Walkthrough

Adds Blackwell-targeted GDN support: new benchmarking CLI, optional kernel exports, Cutlass SMEM layout helpers, a static tile scheduler with MLIR (de)serialization, SM-specific prefill dispatch, and an extensive GPU test suite for fixed/variable lengths and GVA variants.

Changes

Cohort / File(s) Summary
Benchmarking
benchmarks/bench_blackwell_gdn_prefill.py
New end-to-end benchmarking script: fixed/variable-length benchmarks, sweep mode, CUDA timing, formatted tables, CLI entrypoint, and public exports (benchmark_gdn_fixlen, benchmark_gdn_varlen, benchmark_sweep, print_results_table, main).
Kernel package exports
flashinfer/gdn_kernels/__init__.py
Adds public symbol(s) to module __all__ (optional-import pattern) to expose GDN kernel(s) safely when available.
Blackwell package init
flashinfer/gdn_kernels/blackwell_prefill/__init__.py
Exposes chunk_gated_delta_rule from the Blackwell prefill package if importable, otherwise sets it to None; updates __all__.
SMEM Layout Helpers
flashinfer/gdn_kernels/blackwell_prefill/gdn_helpers.py
New helper functions make_smem_layout_a_kind, make_smem_layout_b_kind, make_smem_layout_epi_kind to compute Cutlass NVGPU tcgen05 SMEM layouts for A, B, and epilog tensors.
Tile Scheduler
flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py
Adds GdnStaticTileSchedulerParams, GdnStaticTileScheduler, MLIR extraction/reconstruction, grid-shape logic, persistent/non-persistent scheduling, and factory functions.
SM-specific dispatcher
flashinfer/gdn_prefill.py
Introduces a public chunk_gated_delta_rule wrapper that dispatches to Hopper or Blackwell implementations based on SM capability (is_sm90a_supported, is_sm100a_supported, is_sm110a_supported), with error handling for unsupported parameter combinations.
Tests
tests/gdn/test_gdn_prefill_blackwell.py
Large new GPU test module: FP32 Python reference implementation, fixtures (CUDA sync/seed), tests for fixed/variable lengths, GVA cases, final-state checks, determinism tests, and skip logic for missing kernels/unsupported SMs.
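The SM-specific dispatch in flashinfer/gdn_prefill.py can be pictured with a toy sketch; select_backend and the capability tuples below are illustrative stand-ins for the real is_sm90a_supported()/is_sm100a_supported()/is_sm110a_supported() checks:

```python
def select_backend(compute_capability):
    """Toy dispatch: map a (major, minor) compute capability to a backend."""
    if compute_capability in ((10, 0), (11, 0)):  # SM100a / SM110a
        return "blackwell"
    if compute_capability == (9, 0):              # SM90a
        return "hopper"
    major, minor = compute_capability
    raise ValueError(f"chunk_gated_delta_rule: unsupported SM{major}{minor}")
```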

Sequence Diagram(s)

sequenceDiagram
  participant CLI as "Benchmark CLI"
  participant Host as "Python Host"
  participant Dispatcher as "chunk_gated_delta_rule (wrapper)"
  participant Backend as "Blackwell / Hopper backend"
  participant CUTLASS as "Cutlass (SMEM layouts, scheduler)"
  participant GPU as "GPU (kernel launch)"

  CLI->>Host: parse args, prepare tensors / cu_seqlens
  Host->>Dispatcher: call chunk_gated_delta_rule(q,k,v,g,...)
  Dispatcher->>Backend: select backend (Blackwell or Hopper) by SM
  Backend->>CUTLASS: request SMEM layouts & tile schedule
  CUTLASS->>GPU: launch kernel with layouts & schedule
  GPU-->>CUTLASS: kernel completion/results
  CUTLASS-->>Backend: return outputs/final state
  Backend-->>Dispatcher: outputs/timings
  Dispatcher-->>Host: return outputs
  Host-->>CLI: print results / tables

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

model: qwen3-next

Suggested reviewers

  • kaixih
  • IwakuraRein
  • saltyminty
  • bkryu
  • kahyunnam
  • nv-yunzheq
  • yzh119
  • cyx-6

Poem

🐰 I hopped through tiles and shaped SMEM tight,
I timed each run from warmup to night.
Schedulers bound, kernels hum bright,
Tests nod along in the GPU light.
A rabbit cheers — benchmarks take flight! 🚀

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 58.82%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'Add blackwell GDN prefill impl' clearly identifies the main change: implementing a Blackwell GDN prefill kernel, which aligns with the primary objective and code additions. |
| Description check | ✅ Passed | The PR description is comprehensive and well-structured, covering objectives, features, performance metrics, known limitations, related issues, and checklist completion. All template sections are addressed with substantive content. |



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (1)
tests/gdn/test_gdn_prefill_blackwell.py (1)

24-29: Cover the bf16 path too.

The tolerance constants already account for bf16, but testtype is fixed to fp16, so the bf16 kernel path never runs here. Please parameterize these cases over torch.float16 and torch.bfloat16.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 24 - 29, The test
hardcodes testtype = torch.float16 so the bf16 branch never executes; change the
test to parametrize over both torch.float16 and torch.bfloat16 (e.g.
pytest.mark.parametrize or a small loop) and compute oatol, ortol, satol, srtol
from that parameter (the existing expressions that check `if testtype is
torch.bfloat16` should remain but use the parametrized `testtype`), ensuring the
test runs once for each of the two types and exercises the bf16 kernel path
(update any test function signature that uses `testtype` accordingly).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Around line 93-103: The timed benchmark loop diverges from warmup by setting
output_final_state=True but passing None for the final-state buffer, which
changes workload (allocating or skipping materialization); make the timed path
match the warmup by passing the preallocated state_output when calling
chunk_gated_delta_rule with output_final_state=True (use the same arguments as
the warmup: None, h0, output_final_state, None, False, o, state_output) so each
iteration reuses the preallocated tensor and measures the same work.
- Around line 72-78: The code assumes h0 always exists by calling
torch.zeros_like(h0) which fails when use_initial_state is False; change the
state_output initialization to handle the None case (e.g., set state_output =
torch.zeros_like(h0, dtype=torch.float32) if h0 is not None else None or an
appropriately shaped zero tensor derived from v if later code requires a
tensor), updating the lines that set h0 and state_output (symbols:
use_initial_state, h0, v, state_output, torch.zeros_like) so the "no initial
state" path does not raise.

In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_helpers.py`:
- Around line 65-68: The locals a_smem_shape_mn_k (and the similar unused tuple
at lines ~122-125) are dead and trigger unused-variable warnings; remove these
unused assignments from gdn_helpers.py or replace them with explicit sanity
checks (e.g., assert statements validating a_smem_shape dimensions) if they were
intended as checks—locate the assignments to a_smem_shape_mn_k and the
corresponding a_smem_shape_k_m and either delete the lines or convert them into
assertions referencing a_smem_shape and loc/ip to preserve the intended
validation.

In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 17-23: The test file eagerly imports chunk_gated_delta_rule which
breaks collection on machines without Blackwell/SM100; update
tests/gdn/test_gdn_prefill_blackwell.py to guard collection by using the repo
helpers (get_compute_capability()/is_sm100a_supported()) or pytest.importorskip
before importing chunk_gated_delta_rule, and import the kernel via the optional
public API path (or via
pytest.importorskip("flashinfer.gdn_kernels.blackwell_prefill") ) so the module
is skipped cleanly when SM100/Blackwell DSL is unavailable; ensure references to
chunk_gated_delta_rule remain after the guarded import.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f3d5740f-2232-4a71-b337-adf4709655f1

📥 Commits

Reviewing files that changed from the base of the PR and between fe06b91 and 61cdf98.

📒 Files selected for processing (6)
  • benchmarks/bench_blackwell_gdn_prefill.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/blackwell_prefill/gdn.py
  • flashinfer/gdn_kernels/blackwell_prefill/gdn_helpers.py
  • flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py
  • tests/gdn/test_gdn_prefill_blackwell.py

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new, highly optimized Gated Delta Network (GDN) linear attention kernel tailored for NVIDIA Blackwell GPUs. The implementation, built with CuTe-DSL, aims to enhance the efficiency of linear attention operations by supporting diverse sequence lengths and grouped value attention. It includes dedicated benchmarking and testing to ensure performance gains and correctness, making the advanced GDN capabilities accessible for accelerated deep learning workloads.

Highlights

  • Blackwell GDN Kernel Implementation: Introduced a highly optimized chunked linear attention kernel for Gated Delta Networks (GDN) specifically designed for NVIDIA Blackwell (SM100) architecture, utilizing CuTe-DSL.
  • Feature Support: The new kernel supports persistent and non-persistent modes, fixed-length and variable-length sequences (cu_seqlens), grouped value attention (GVA) where h_v is a multiple of h_q, head dimension of 128, and input/output data types of f16/bf16. It also handles initial states and returns final states.
  • Performance Benchmarking: Added a comprehensive benchmarking tool to evaluate the performance of the Blackwell GDN prefill implementation against a FLA baseline, demonstrating significant speedups (up to 4.2x) across various configurations.
  • Extensive Testing: Included a robust test suite to validate the correctness of the GDN kernel's output and state for both fixed-length and variable-length sequences, including scenarios with grouped value attention.
  • Public API Exposure: The new chunk_gated_delta_rule function has been exposed through the flashinfer.gdn_kernels API for broader usability.


Changelog
  • benchmarks/bench_blackwell_gdn_prefill.py
    • Added a new benchmark script for GDN prefill on Blackwell, supporting fixed-length, variable-length, and sweep configurations, with comparison against FLA baseline.
  • flashinfer/gdn_kernels/__init__.py
    • Updated to expose the new chunk_gated_delta_rule function from the Blackwell prefill kernels.
  • flashinfer/gdn_kernels/blackwell_prefill/gdn_helpers.py
    • Added helper functions for shared memory layout generation within the CuTe-DSL framework for GDN kernels.
  • flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py
    • Added a new tile scheduler for GDN kernels, managing work distribution for persistent and non-persistent modes.
  • tests/gdn/test_gdn_prefill_blackwell.py
    • Added comprehensive unit tests for the Blackwell GDN prefill kernel, covering fixed-length, variable-length, and grouped value attention scenarios, and verifying output and state correctness against a reference implementation.
Activity
  • No human activity to report on this pull request yet.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new Gated Delta Network (GDN) prefill kernel for NVIDIA Blackwell GPUs, implemented using CuTe-DSL. The changes are comprehensive, including the kernel implementation, helper utilities, a tile scheduler, new tests, and benchmarks. The code is well-structured and adds significant new capabilities. My review identifies a few minor areas for improvement, mainly related to code cleanup in helper files, consistency in the benchmark script, and a redundant operation in the test's reference implementation. Overall, this is a great contribution.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
tests/gdn/test_gdn_prefill_blackwell.py (1)

17-23: ⚠️ Potential issue | 🟠 Major

Guard test collection on SM100 availability.

This import will fail during test collection on non-SM100 machines or builds without the Blackwell DSL. Use pytest.importorskip or the repo's SM100 helpers to skip cleanly on unsupported environments.

As per coding guidelines: "Use flashinfer.utils functions (is_sm100a_supported()) to skip tests on unsupported GPU architectures".

🛠️ Suggested fix
 import pytest
 import torch
 import torch.nn.functional as F
 
-from flashinfer.gdn_kernels.blackwell_prefill.gdn import chunk_gated_delta_rule
+from flashinfer.utils import is_sm100a_supported
+
+# Skip entire module if SM100 is not supported
+pytestmark = pytest.mark.skipif(
+    not is_sm100a_supported(),
+    reason="Blackwell (SM100) not supported on this device"
+)
+
+# Import after guard to prevent collection failures
+chunk_gated_delta_rule = pytest.importorskip(
+    "flashinfer.gdn_kernels.blackwell_prefill.gdn"
+).chunk_gated_delta_rule
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 17 - 23, The test
imports Blackwell-only code unguarded which fails on non-SM100 builds; guard
collection by skipping when SM100/Blackwell DSL is unavailable: at the top of
tests/gdn/test_gdn_prefill_blackwell.py use pytest.importorskip for the
Blackwell module or call flashinfer.utils.is_sm100a_supported() and pytest.skip
if False before importing chunk_gated_delta_rule so the import of
chunk_gated_delta_rule and definition of testtype only occur on supported
hardware.
🧹 Nitpick comments (6)
tests/gdn/test_gdn_prefill_blackwell.py (2)

176-192: Prefix unused unpacked variable with underscore.

The o variable is unpacked but never used in test_fixlen_state. Use _ or _o to indicate intentional discard and silence the linter warning.

♻️ Suggested fix
-        o, state_output = chunk_gated_delta_rule(
+        _o, state_output = chunk_gated_delta_rule(

This also applies to similar patterns in test_varlen_state (line 313), test_fixlen_gva_state (line 444), and test_varlen_gva_state (line 587).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 176 - 192, The unpacked
output variable `o` from the call to chunk_gated_delta_rule is unused and should
be prefixed with an underscore to silence the linter; update the unpackings in
test_fixlen_state, test_varlen_state, test_fixlen_gva_state, and
test_varlen_gva_state so the first element is assigned to `_` or `_o` (e.g.,
replace `o, state_output = chunk_gated_delta_rule(...)` with `_, state_output =
chunk_gated_delta_rule(...)`) while leaving the function call and `state_output`
handling unchanged.

31-40: Use explicit Optional type hint.

PEP 484 prohibits implicit Optional. The scale and initial_state parameters should use explicit Optional[T] or T | None syntax.

♻️ Suggested fix
+from typing import Optional
+
 def recurrent_gated_delta_rule_ref(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     g: torch.Tensor,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
 ):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 31 - 40, The function
recurrent_gated_delta_rule_ref has parameters scale and initial_state typed as
defaulting to None; change their annotations to explicit Optional types (e.g.,
Optional[float] and Optional[torch.Tensor] or float | None and torch.Tensor |
None) and import typing.Optional if using that form; update the function
signature accordingly (referencing recurrent_gated_delta_rule_ref, scale, and
initial_state) so the hints comply with PEP 484.
benchmarks/bench_blackwell_gdn_prefill.py (3)

36-39: Consider importing from public API.

The benchmark imports directly from the internal module path. For consistency and to verify the public API works correctly, consider importing via the public interface.

♻️ Suggested change
-from flashinfer.gdn_kernels.blackwell_prefill.gdn import chunk_gated_delta_rule
+from flashinfer.gdn_kernels import chunk_gated_delta_rule
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 36 - 39, The
benchmark imports chunk_gated_delta_rule and chunk_gated_delta_rule_fwd directly
from internal modules; update the imports to use the package public API instead
(import the same symbols via their public export points) so the benchmark
verifies the public interface and avoids internal paths—replace the direct
internal imports of chunk_gated_delta_rule (from
flashinfer.gdn_kernels.blackwell_prefill.gdn) and chunk_gated_delta_rule_fwd
(from fla.ops.gated_delta_rule.chunk) with their corresponding public API
imports.

153-162: Prefer def over lambda assignment.

Static analysis (Ruff E731) flags the lambda assignment. This is a minor style issue but using def is more Pythonic and allows for better debugging.

♻️ Suggested fix
-        fn_fla = lambda: fla_base(
-            q,
-            k,
-            v,
-            g,
-            beta,
-            None,
-            initial_state=h0,
-            output_final_state=output_final_state,
-        )
+        def fn_fla():
+            return fla_base(
+                q,
+                k,
+                v,
+                g,
+                beta,
+                None,
+                initial_state=h0,
+                output_final_state=output_final_state,
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 153 - 162, Replace
the lambda assignment to fn_fla with a proper named function definition: create
a def fn_fla(...): that calls fla_base(q, k, v, g, beta, None, initial_state=h0,
output_final_state=output_final_state) and returns its result, preserving the
same captured variables (q, k, v, g, beta, h0, output_final_state); this removes
the lambda (Ruff E731) and makes debugging/tracing of fn_fla clearer while
keeping behavior identical.

122-148: Consider using flashinfer.testing.bench_gpu_time() for timing.

The benchmark uses manual CUDA events for timing. As per coding guidelines: "Use flashinfer.testing.bench_gpu_time() for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events."

♻️ Suggested approach
from flashinfer.testing import bench_gpu_time

# Replace manual event timing with:
avg_latency_ms = bench_gpu_time(
    lambda: chunk_gated_delta_rule(
        q, k, v, g, beta,
        None, h0, output_final_state, None, False, o, None,
    ),
    warmup_iters=warmup_iters,
    benchmark_iters=benchmark_iters,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 122 - 148, The timing
block uses manual torch.cuda.Event timing; replace it by calling
flashinfer.testing.bench_gpu_time and pass a lambda that calls
chunk_gated_delta_rule(q, k, v, g, beta, None, h0, output_final_state, None,
False, o, None), remove start_event/end_event, torch.cuda.synchronize, and
elapsed_time/avg computation, and assign avg_latency_ms to the bench_gpu_time
return value; also add the import for bench_gpu_time and forward warmup_iters
and benchmark_iters to the bench_gpu_time call.
flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py (1)

140-150: Magic number in assertion is fragile.

The hard-coded len(values) == 10 is brittle. If the MLIR value structure changes (e.g., adding fields to params or changing coordinate dimensions), this assertion will fail without clear indication of what changed.

♻️ Consider deriving expected length
     def __new_from_mlir_values__(self, values):
-        assert len(values) == 10
+        # Expected: 3 (params) + 1 (current_work_linear_idx) + 3 (blk_coord) + 3 (grid_shape)
+        expected_len = 10
+        assert len(values) == expected_len, f"Expected {expected_len} values, got {len(values)}"
         new_params = cutlass.new_from_mlir_values(self._params, values[0:3])

Or better, compute the expected length from the extraction method to keep them in sync.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py` around lines
140 - 150, The assertion uses a brittle magic number (len(values) == 10);
replace it by computing the expected number of MLIR values from the actual
components so the check stays in sync with structure changes: compute
expected_count as the sum of MLIR-value counts for self._params,
self._current_work_linear_idx, self._blk_coord and self._grid_shape (e.g.
expected_count = cutlass.mlir_value_count(self._params) +
cutlass.mlir_value_count(self._current_work_linear_idx) +
cutlass.mlir_value_count(self._blk_coord) +
cutlass.mlir_value_count(self._grid_shape) or, if that helper doesn’t exist, add
a small helper that inspects each object’s stored mlir values to return their
lengths), then assert len(values) == expected_count before slicing and calling
cutlass.new_from_mlir_values in __new_from_mlir_values__ when constructing
GdnStaticTileScheduler.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py`:
- Around line 30-37: In __new_from_mlir_values__ (in
GdnStaticTileSchedulerParams reconstruction) the created object omits the
original ip context; update the final constructor call
GdnStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) to also pass
ip=self._ip so the reconstructed object preserves both loc and ip; locate
references to self.is_persistent, self.problem_shape_mbh, self._values_pos and
cutlass.new_from_mlir_values when making the change.

---

Duplicate comments:
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 17-23: The test imports Blackwell-only code unguarded which fails
on non-SM100 builds; guard collection by skipping when SM100/Blackwell DSL is
unavailable: at the top of tests/gdn/test_gdn_prefill_blackwell.py use
pytest.importorskip for the Blackwell module or call
flashinfer.utils.is_sm100a_supported() and pytest.skip if False before importing
chunk_gated_delta_rule so the import of chunk_gated_delta_rule and definition of
testtype only occur on supported hardware.

---

Nitpick comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Around line 36-39: The benchmark imports chunk_gated_delta_rule and
chunk_gated_delta_rule_fwd directly from internal modules; update the imports to
use the package public API instead (import the same symbols via their public
export points) so the benchmark verifies the public interface and avoids
internal paths—replace the direct internal imports of chunk_gated_delta_rule
(from flashinfer.gdn_kernels.blackwell_prefill.gdn) and
chunk_gated_delta_rule_fwd (from fla.ops.gated_delta_rule.chunk) with their
corresponding public API imports.
- Around line 153-162: Replace the lambda assignment to fn_fla with a proper
named function definition: create a def fn_fla(...): that calls fla_base(q, k,
v, g, beta, None, initial_state=h0, output_final_state=output_final_state) and
returns its result, preserving the same captured variables (q, k, v, g, beta,
h0, output_final_state); this removes the lambda (Ruff E731) and makes
debugging/tracing of fn_fla clearer while keeping behavior identical.
- Around line 122-148: The timing block uses manual torch.cuda.Event timing;
replace it by calling flashinfer.testing.bench_gpu_time and pass a lambda that
calls chunk_gated_delta_rule(q, k, v, g, beta, None, h0, output_final_state,
None, False, o, None), remove start_event/end_event, torch.cuda.synchronize, and
elapsed_time/avg computation, and assign avg_latency_ms to the bench_gpu_time
return value; also add the import for bench_gpu_time and forward warmup_iters
and benchmark_iters to the bench_gpu_time call.
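The shape of the replacement is roughly the following; this is a CPU stand-in using `perf_counter` to illustrate the warmup/median pattern (the real `flashinfer.testing.bench_gpu_time` times with CUDA events internally, and its exact signature should be checked against the library):

```python
import statistics
import time

def bench_time_ms(fn, warmup_iters: int = 3, benchmark_iters: int = 10) -> float:
    """Illustrative stand-in for bench_gpu_time: run warmups,
    then return the median latency of benchmark_iters calls in ms."""
    for _ in range(warmup_iters):
        fn()
    samples = []
    for _ in range(benchmark_iters):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)

# The benchmark would then assign the return value directly:
avg_latency_ms = bench_time_ms(lambda: sum(range(1000)))
```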

In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py`:
- Around line 140-150: The assertion uses a brittle magic number (len(values) ==
10); replace it by computing the expected number of MLIR values from the actual
components so the check stays in sync with structure changes: compute
expected_count as the sum of MLIR-value counts for self._params,
self._current_work_linear_idx, self._blk_coord and self._grid_shape (e.g.
expected_count = cutlass.mlir_value_count(self._params) +
cutlass.mlir_value_count(self._current_work_linear_idx) +
cutlass.mlir_value_count(self._blk_coord) +
cutlass.mlir_value_count(self._grid_shape) or, if that helper doesn’t exist, add
a small helper that inspects each object’s stored mlir values to return their
lengths), then assert len(values) == expected_count before slicing and calling
cutlass.new_from_mlir_values in __new_from_mlir_values__ when constructing
GdnStaticTileScheduler.
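The intent of that check can be sketched with stand-in objects; the `__extract_mlir_values__` hook name follows the CuTe-DSL convention but should be treated as an assumption here:

```python
class _Component:
    """Stand-in for a scheduler field that stores MLIR values."""
    def __init__(self, vals):
        self._vals = list(vals)

    def __extract_mlir_values__(self):
        return self._vals

def mlir_value_count(obj) -> int:
    """Count of MLIR values an object contributes during reconstruction."""
    return len(obj.__extract_mlir_values__())

# Illustrative components mirroring params / work index / block coord / grid shape:
params = _Component([1, 2, 3])
work_idx = _Component([4])
blk_coord = _Component([5, 6, 7])
grid_shape = _Component([8, 9, 10])

expected_count = sum(
    mlir_value_count(c) for c in (params, work_idx, blk_coord, grid_shape)
)
values = list(range(expected_count))
assert len(values) == expected_count  # replaces the magic len(values) == 10
```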

In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 176-192: The unpacked output variable `o` from the call to
chunk_gated_delta_rule is unused and should be prefixed with an underscore to
silence the linter; update the unpackings in test_fixlen_state,
test_varlen_state, test_fixlen_gva_state, and test_varlen_gva_state so the first
element is assigned to `_` or `_o` (e.g., replace `o, state_output =
chunk_gated_delta_rule(...)` with `_, state_output =
chunk_gated_delta_rule(...)`) while leaving the function call and `state_output`
handling unchanged.
- Around line 31-40: The function recurrent_gated_delta_rule_ref has parameters
scale and initial_state typed as defaulting to None; change their annotations to
explicit Optional types (e.g., Optional[float] and Optional[torch.Tensor] or
float | None and torch.Tensor | None) and import typing.Optional if using that
form; update the function signature accordingly (referencing
recurrent_gated_delta_rule_ref, scale, and initial_state) so the hints comply
with PEP 484.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 28ac11e0-8733-444d-a346-5bd4528d3929

📥 Commits

Reviewing files that changed from the base of the PR and between 61cdf98 and f1c236d.

📒 Files selected for processing (6)
  • benchmarks/bench_blackwell_gdn_prefill.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/blackwell_prefill/gdn.py
  • flashinfer/gdn_kernels/blackwell_prefill/gdn_helpers.py
  • flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py
  • tests/gdn/test_gdn_prefill_blackwell.py

@vadiklyutiy vadiklyutiy moved this to In review in Qwen3.5 Mar 11, 2026
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
tests/gdn/test_gdn_prefill_blackwell.py (2)

50-50: Use explicit Optional[float] type hint.

PEP 484 prohibits implicit Optional. The parameter default None requires explicit optional typing.

♻️ Suggested fix
+from typing import Optional
+
 def recurrent_gated_delta_rule_ref(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     g: torch.Tensor,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
 ):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_prefill_blackwell.py` at line 50, The parameter signature
uses an implicit optional type "scale: float = None"; change it to an explicit
Optional type by updating the parameter to "scale: Optional[float]" and add
"from typing import Optional" to the imports in
tests/gdn/test_gdn_prefill_blackwell.py so the annotation is PEP 484-compliant;
update any other occurrences of the same signature in that file to match.

208-221: Prefix unused o with underscore.

In state-specific tests, the output tensor o is unpacked but not used. Prefix with _ to indicate it's intentionally ignored.

♻️ Suggested fix (apply to lines 208, 348, 485, 634)
-        o, state_output = chunk_gated_delta_rule(
+        _o, state_output = chunk_gated_delta_rule(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 208 - 221, The test
unpacks the return value `o` from chunk_gated_delta_rule but never uses it;
rename it to `_o` (or simply `_`) in each unpacking to signal the unused tensor
and satisfy lint/tests — update occurrences where chunk_gated_delta_rule is
called (e.g., the unpacking at the call that currently assigns `o, state_output
= chunk_gated_delta_rule(...)`) to `_o, state_output =
chunk_gated_delta_rule(...)` (apply the same change to the other similar calls
mentioned).
benchmarks/bench_blackwell_gdn_prefill.py (1)

153-162: Prefer def over lambda assignment.

Ruff E731 flags this lambda assignment. Using a named function improves readability and debuggability.

♻️ Suggested refactor
-        fn_fla = lambda: fla_base(
-            q,
-            k,
-            v,
-            g,
-            beta,
-            None,
-            initial_state=h0,
-            output_final_state=output_final_state,
-        )
+        def fn_fla():
+            return fla_base(
+                q,
+                k,
+                v,
+                g,
+                beta,
+                None,
+                initial_state=h0,
+                output_final_state=output_final_state,
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 153 - 162, The lambda
assigned to fn_fla should be replaced with a named function to satisfy Ruff E731
and improve readability: define a function (e.g., def fn_fla()) that accepts no
arguments and calls fla_base(q, k, v, g, beta, None, initial_state=h0,
output_final_state=output_final_state) and then use that function in place of
the lambda; keep the same references to fla_base, h0, and output_final_state so
behavior is unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Line 35: Fix the typo in the header comment: change the comment string
"Blacwell GDN prefill" to "Blackwell GDN prefill" so the module description
reads correctly (look for the comment line containing "Blacwell GDN prefill" in
benchmarks/bench_blackwell_gdn_prefill.py).

---

Nitpick comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Around line 153-162: The lambda assigned to fn_fla should be replaced with a
named function to satisfy Ruff E731 and improve readability: define a function
(e.g., def fn_fla()) that accepts no arguments and calls fla_base(q, k, v, g,
beta, None, initial_state=h0, output_final_state=output_final_state) and then
use that function in place of the lambda; keep the same references to fla_base,
h0, and output_final_state so behavior is unchanged.

In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Line 50: The parameter signature uses an implicit optional type "scale: float
= None"; change it to an explicit Optional type by updating the parameter to
"scale: Optional[float]" and add "from typing import Optional" to the imports in
tests/gdn/test_gdn_prefill_blackwell.py so the annotation is PEP 484-compliant;
update any other occurrences of the same signature in that file to match.
- Around line 208-221: The test unpacks the return value `o` from
chunk_gated_delta_rule but never uses it; rename it to `_o` (or simply `_`) in
each unpacking to signal the unused tensor and satisfy lint/tests — update
occurrences where chunk_gated_delta_rule is called (e.g., the unpacking at the
call that currently assigns `o, state_output = chunk_gated_delta_rule(...)`) to
`_o, state_output = chunk_gated_delta_rule(...)` (apply the same change to the
other similar calls mentioned).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b0fdf0a2-3caf-4cae-8abe-675099594482

📥 Commits

Reviewing files that changed from the base of the PR and between f1c236d and 13aa6ae.

📒 Files selected for processing (2)
  • benchmarks/bench_blackwell_gdn_prefill.py
  • tests/gdn/test_gdn_prefill_blackwell.py

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (4)
flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py (1)

130-133: Consider explicit Int32 cast for type consistency.

The similar implementation in grouped_gemm_masked_blackwell.py explicitly casts advance_count to Int32 before arithmetic:

self._current_work_linear_idx += Int32(advance_count) * Int32(self.num_persistent_clusters)

This ensures type consistency in MLIR code generation. The current implementation may work due to implicit conversion, but explicit casting would be more robust.

♻️ Suggested improvement
     def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None):
         if self._is_persistent:
-            self._current_work_linear_idx += advance_count * self.num_persistent_sm
+            self._current_work_linear_idx += Int32(advance_count) * self.num_persistent_sm
         self._is_first_block = False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py` around lines
130 - 133, In advance_to_next_work ensure MLIR integer types are explicit by
casting advance_count and num_persistent_sm to Int32 before the multiplication
and addition to _current_work_linear_idx; specifically, in the method
advance_to_next_work (check _is_persistent branch) replace the implicit
arithmetic with an explicit Int32(advance_count) * Int32(self.num_persistent_sm)
and add that result to _current_work_linear_idx so type consistency matches the
grouped_gemm_masked_blackwell pattern.
benchmarks/bench_blackwell_gdn_prefill.py (3)

476-478: Consider adding GVA constraint validation.

The kernel requires num_v_heads % num_qk_heads == 0 for grouped value attention (GVA). Adding this validation at the CLI level would provide clearer error messages than the kernel warning.

♻️ Suggested addition after head_dim check
     if args.head_dim != 128:
         print(f"Error: head_dim must be 128, got {args.head_dim}")
         sys.exit(1)
+
+    if args.num_v_heads % args.num_qk_heads != 0:
+        print(f"Error: num_v_heads ({args.num_v_heads}) must be divisible by num_qk_heads ({args.num_qk_heads}) for GVA")
+        sys.exit(1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 476 - 478, Add a
CLI-level validation that enforces the grouped value attention constraint by
checking args.num_v_heads % args.num_qk_heads == 0 (after the existing head_dim
check); if the condition fails, print a clear error like "Error: num_v_heads
must be divisible by num_qk_heads, got {args.num_v_heads} and
{args.num_qk_heads}" and call sys.exit(1) so the program exits with a helpful
message before hitting the kernel.

153-162: Prefer def over assigned lambda.

Per Python style guidelines (PEP 8), named functions should use def rather than assigning a lambda to a variable.

♻️ Suggested fix
-        fn_fla = lambda: fla_base(
-            q,
-            k,
-            v,
-            g,
-            beta,
-            None,
-            initial_state=h0,
-            output_final_state=output_final_state,
-        )
+        def fn_fla():
+            return fla_base(
+                q,
+                k,
+                v,
+                g,
+                beta,
+                None,
+                initial_state=h0,
+                output_final_state=output_final_state,
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 153 - 162, Replace
the assigned lambda fn_fla with a proper named function using def to follow PEP
8: create def fn_fla(): that calls fla_base(q, k, v, g, beta, None,
initial_state=h0, output_final_state=output_final_state) and returns its result;
update any references to fn_fla unchanged. Ensure parameter capture uses the
same outer-scope variables (q, k, v, g, beta, h0, output_final_state) as in the
original lambda.

362-364: Consider narrowing exception handling.

Catching bare Exception also catches KeyboardInterrupt and SystemExit, preventing clean script termination. For benchmark robustness while allowing user interrupts, consider catching RuntimeError or a more specific exception type.

♻️ Suggested fix (also apply to line 384)
-        except Exception as e:
+        except (RuntimeError, ValueError, torch.cuda.CudaError) as e:
             print(f"  FAILED: {e}")
             torch.cuda.empty_cache()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 362 - 364, The
current except block uses a broad "except Exception as e:" which also intercepts
KeyboardInterrupt/SystemExit; replace it with a narrower exception handler
(e.g., "except RuntimeError as e:") and keep the same cleanup
(torch.cuda.empty_cache()) and logging, and apply the same change to the similar
handler later in the file (the other except block that prints FAILED and calls
torch.cuda.empty_cache()). If you must catch multiple runtime errors, list them
explicitly or, if you need to preserve re-raising interrupts, add a guard to
re-raise KeyboardInterrupt and SystemExit before handling other exceptions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py`:
- Around line 142-152: The __new_from_mlir_values__ method reconstructs a
GdnStaticTileScheduler but omits passing the loc and ip context parameters;
update the return to pass the original object's loc and ip (e.g., self.loc and
self.ip) into the GdnStaticTileScheduler constructor so the reconstructed
scheduler preserves the source debug/location context; modify the final return
in __new_from_mlir_values__ of GdnStaticTileScheduler to include loc and ip
along with new_params, new_current_work_linear_idx, new_blk_coord, and
new_grid_shape.

---

Nitpick comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Around line 476-478: Add a CLI-level validation that enforces the grouped
value attention constraint by checking args.num_v_heads % args.num_qk_heads == 0
(after the existing head_dim check); if the condition fails, print a clear error
like "Error: num_v_heads must be divisible by num_qk_heads, got
{args.num_v_heads} and {args.num_qk_heads}" and call sys.exit(1) so the program
exits with a helpful message before hitting the kernel.
- Around line 153-162: Replace the assigned lambda fn_fla with a proper named
function using def to follow PEP 8: create def fn_fla(): that calls fla_base(q,
k, v, g, beta, None, initial_state=h0, output_final_state=output_final_state)
and returns its result; update any references to fn_fla unchanged. Ensure
parameter capture uses the same outer-scope variables (q, k, v, g, beta, h0,
output_final_state) as in the original lambda.
- Around line 362-364: The current except block uses a broad "except Exception
as e:" which also intercepts KeyboardInterrupt/SystemExit; replace it with a
narrower exception handler (e.g., "except RuntimeError as e:") and keep the same
cleanup (torch.cuda.empty_cache()) and logging, and apply the same change to the
similar handler later in the file (the other except block that prints FAILED and
calls torch.cuda.empty_cache()). If you must catch multiple runtime errors, list
them explicitly or, if you need to preserve re-raising interrupts, add a guard
to re-raise KeyboardInterrupt and SystemExit before handling other exceptions.

In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py`:
- Around line 130-133: In advance_to_next_work ensure MLIR integer types are
explicit by casting advance_count and num_persistent_sm to Int32 before the
multiplication and addition to _current_work_linear_idx; specifically, in the
method advance_to_next_work (check _is_persistent branch) replace the implicit
arithmetic with an explicit Int32(advance_count) * Int32(self.num_persistent_sm)
and add that result to _current_work_linear_idx so type consistency matches the
grouped_gemm_masked_blackwell pattern.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 11adc012-ba34-4bda-89ae-038b83eab541

📥 Commits

Reviewing files that changed from the base of the PR and between 13aa6ae and fab4e0e.

📒 Files selected for processing (2)
  • benchmarks/bench_blackwell_gdn_prefill.py
  • flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py

try:
    from .blackwell_prefill.gdn import chunk_gated_delta_rule
except ImportError:
    chunk_gated_delta_rule = None  # type: ignore

Can you call this function on SM100 in https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_prefill.py#L86? The chunk_gated_delta_rule function in this file should be the only API for the user to call the prefill gdn kernel. It only supports SM90 on the main branch.

Collaborator Author

@hlu1 sure, thanks for pointing this out

Collaborator Author

@hlu1 done, changed the API to gdn_prefill.

state_output = torch.zeros_like(h0, dtype=torch.float)

o, state_output = chunk_gated_delta_rule(
    q.clone(),
Collaborator

Why do we need these clones here?

Collaborator Author

@yzh119 thanks for pointing this out. For unit tests, we do not trust the code under test, so we cannot guarantee that these tensors are not modified in place by the function being tested; that is why I used clones.

I have now removed these clones.

None,
state_output,
)
torch.cuda.synchronize()
Collaborator

and the synchronization

Collaborator Author

thanks for pointing this out; done, removed the synchronization

ref_ht = torch.cat((ref_ht, ref_ht_i), dim=0).contiguous()

ref_ht = torch.transpose(ref_ht, -1, -2).contiguous()
torch.cuda.synchronize()
Collaborator

ditto

Collaborator Author

ditto and done

None,
state_output,
)
torch.cuda.synchronize()
Collaborator

ditto

Collaborator Author

ditto and done

state_output = torch.zeros_like(h0, dtype=torch.float)

o, state_output = chunk_gated_delta_rule(
    q.clone(),
Collaborator

ditto

Collaborator Author

ditto and done

None,
state_output,
)
torch.cuda.synchronize()
Collaborator

ditto

Collaborator Author

ditto and done

        lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g]
    )
    B, H, T, K, V = *k.shape, v.shape[-1]
    o = torch.zeros(B, H, T, V).to(v)
Collaborator

It's much faster to create a tensor directly on the GPU than to move it from CPU to GPU, especially in unit tests.

Suggested change
-o = torch.zeros(B, H, T, V).to(v)
+o = torch.zeros(B, H, T, V, device=v.device)

Collaborator Author

thanks for pointing this out; done, q, k, v are now created as GPU tensors

):
    """Reference implementation of gated delta rule (recurrent version)."""
    q, k, v, beta, g = map(
        lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g]
Collaborator

It will be much faster if we can pass GPU tensors directly.

Collaborator Author

thanks for pointing this out; done, q, k, v are now created as GPU tensors


device = "cuda"

q = torch.randn((B, T, H, D), dtype=dtype)
Collaborator

create these tensors on GPU

Collaborator Author

ditto and done


# Benchmark
torch.cuda.reset_peak_memory_stats()
start_event = torch.cuda.Event(enable_timing=True)

Collaborator Author

@yzh119 thanks for pointing this out; done, now using the bench_gpu_time function

Change func api position

Bench use bench_gpu_time func

Update default test config
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (1)
tests/gdn/test_gdn_prefill_blackwell.py (1)

26-33: ⚠️ Potential issue | 🟠 Major

Derive the skip flag from the actual Blackwell backend, not the wrapper import.

Importing flashinfer.gdn_prefill is not enough here: flashinfer.gdn_kernels.blackwell_prefill can export chunk_gated_delta_rule = None on ImportError, so GDN_BLACKWELL_PREFILL_CUTEDSL can be True even when the Blackwell DSL backend is missing. Please key this off the actual Blackwell symbol (or pytest.importorskip) so SM100/SM110 builders skip cleanly instead of failing at call time.

Also applies to: 82-90

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 26 - 33, The test's
skip flag is derived from importing flashinfer.gdn_prefill instead of the actual
Blackwell backend symbol; update the check to verify the real backend export
(e.g. flashinfer.gdn_kernels.blackwell_prefill.chunk_gated_delta_rule) or use
pytest.importorskip("flashinfer.gdn_kernels.blackwell_prefill") to skip when the
Blackwell DSL is missing, and set GDN_BLACKWELL_PREFILL_CUTEDSL based on the
presence/non-None value of chunk_gated_delta_rule rather than the wrapper import
success.
🧹 Nitpick comments (1)
tests/gdn/test_gdn_prefill_blackwell.py (1)

36-41: Please exercise the BF16 path somewhere in this suite.

The tests hardcode torch.float16, so none of the new cases cover the other advertised dtype. A small torch.bfloat16 parameterization or smoke subset would catch dtype-specific regressions in the Blackwell kernel.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 36 - 41, The test
currently hardcodes testtype = torch.float16 so BF16 paths are never exercised;
change the test to parametrize testtype over (torch.float16, torch.bfloat16) (or
add a small smoke test that sets testtype = torch.bfloat16) and adjust the
tolerance logic that depends on testtype (oatol, ortol, satol, srtol) to pick
the correct values for each dtype; look for the variable testtype and the
tolerance assignments in the test_gdn_prefill_blackwell tests and update them to
run/assert for both dtypes (use pytest.mark.parametrize on the surrounding test
function or add an additional test function that mirrors the current assertions
but uses torch.bfloat16).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Line 38: The benchmark imports the public dispatcher (chunk_gated_delta_rule)
which will trigger Hopper (SM90) kernels on unsupported hardware; add a
fail-fast GPU check at the top of the module before importing or invoking
chunk_gated_delta_rule (and duplicate checks for the other similar import/usage
blocks around the 428-437 region): query the current CUDA device properties
(e.g., via torch.cuda.get_device_properties(device) or the CUDA runtime API),
extract the SM/compute capability, and if it is not the target SM100 or SM110
family raise a clear RuntimeError (or exit) explaining the benchmark requires
SM100/SM110 hardware so the Hopper path is not executed.
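A sketch of that fail-fast check, assuming the SM100/SM103/SM110 families map to compute capability major 10 and 11 (`torch.cuda.get_device_capability()` returns a `(major, minor)` tuple):

```python
def require_blackwell(capability):
    """Raise unless capability=(major, minor) is in the SM100/SM103/SM110
    family, i.e. compute capability major 10 or 11."""
    major, minor = capability
    if major not in (10, 11):
        raise RuntimeError(
            f"This benchmark requires SM100/SM110-class GPUs, got sm_{major}{minor}"
        )

# At module top, before touching chunk_gated_delta_rule, one would call:
# require_blackwell(torch.cuda.get_device_capability())
```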

In `@flashinfer/gdn_prefill.py`:
- Around line 203-209: The public docstring promises that g=None and beta=None
default to all-ones, but current code defers this to backend dispatch and raises
NotImplementedError on SM100/SM110; eagerly materialize defaults before dispatch
by checking if g or beta is None and, using total_seq_len and num_sab_heads,
create float32 torch.ones tensors (shape [total_seq_len, num_sab_heads]) so the
backend always receives concrete tensors; apply the same change to the other
occurrence mentioned (around the 259-262 region) so behavior matches the docs
across backends.
- Around line 176-177: The API wrapper chunk_gated_delta_rule currently handles
SM90/SM100/SM110 dispatch but lacks the standard backend-requirement contract;
add the `@backend_requirement` decorator to chunk_gated_delta_rule and to the
other compute-gated wrappers in the same region (the functions around the
240-280 block), and attach the required support helpers to each function:
implement and assign is_compute_capability_supported(cc) to mirror the dispatch
logic (return True for SM90/SM100/SM110 where appropriate) and implement
is_backend_supported() to reflect available backends used by the wrapper; ensure
the decorator import is present and the helper functions are attached as
attributes on the function objects so other code can call
function.is_compute_capability_supported(cc) and
function.is_backend_supported().
- Around line 256-277: The branch that calls chunk_gated_delta_rule_blackwell
must first verify the optional Blackwell backend is available (it may be None
when flashinfer.gdn_kernels.blackwell_prefill failed to import); before invoking
chunk_gated_delta_rule_blackwell(q, k, v, g, beta, ...), check that
chunk_gated_delta_rule_blackwell is not None and if it is None raise a clear
NotImplementedError/RuntimeError like "Blackwell CuTe kernel not available on
this build" (keep the existing gate/beta check), so you fail with an explicit
availability error instead of a 'NoneType' object is not callable'.

---

Duplicate comments:
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 26-33: The test's skip flag is derived from importing
flashinfer.gdn_prefill instead of the actual Blackwell backend symbol; update
the check to verify the real backend export (e.g.
flashinfer.gdn_kernels.blackwell_prefill.chunk_gated_delta_rule) or use
pytest.importorskip("flashinfer.gdn_kernels.blackwell_prefill") to skip when the
Blackwell DSL is missing, and set GDN_BLACKWELL_PREFILL_CUTEDSL based on the
presence/non-None value of chunk_gated_delta_rule rather than the wrapper import
success.

---

Nitpick comments:
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 36-41: The test currently hardcodes testtype = torch.float16 so
BF16 paths are never exercised; change the test to parametrize testtype over
(torch.float16, torch.bfloat16) (or add a small smoke test that sets testtype =
torch.bfloat16) and adjust the tolerance logic that depends on testtype (oatol,
ortol, satol, srtol) to pick the correct values for each dtype; look for the
variable testtype and the tolerance assignments in the
test_gdn_prefill_blackwell tests and update them to run/assert for both dtypes
(use pytest.mark.parametrize on the surrounding test function or add an
additional test function that mirrors the current assertions but uses
torch.bfloat16).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9b400e8a-8a96-489c-930a-cc00a4715bc2

📥 Commits

Reviewing files that changed from the base of the PR and between fab4e0e and c358a07.

📒 Files selected for processing (7)
  • benchmarks/bench_blackwell_gdn_prefill.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/blackwell_prefill/__init__.py
  • flashinfer/gdn_kernels/blackwell_prefill/gdn.py
  • flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py
  • flashinfer/gdn_prefill.py
  • tests/gdn/test_gdn_prefill_blackwell.py

@dianzhangchen dianzhangchen requested a review from yzh119 March 17, 2026 02:17
@kaixih kaixih left a comment

I know this PR is focused on SM100+ prefill support, but just want to clarify: since this prefill kernel outputs FP32 state while the fast decode path requires BF16 state, users will still need an explicit dtype conversion between prefill and decode until gap 5.4 (prefill BF16 state) is addressed. Is this a known limitation or something we plan to tackle soon?

@dianzhangchen
Collaborator Author

dianzhangchen commented Mar 20, 2026

I know this PR is focused on SM100+ prefill support, but just want to clarify: since this prefill kernel outputs FP32 state while the fast decode path requires BF16 state, users will still need an explicit dtype conversion between prefill and decode until gap 5.4 (prefill BF16 state) is addressed. Is this a known limitation or something we plan to tackle soon?

hi @kaixih , this is a known limitation and will be addressed in the next version.
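Until that limitation is addressed, the conversion described above is a one-liner on the user side; a minimal sketch with illustrative shapes:

```python
import torch

# Sketch of the workaround described above: the Blackwell prefill kernel
# returns the recurrent state in fp32, while the fast decode path expects
# bf16, so the state is converted explicitly between the two calls.
# Shapes are illustrative (B=2 sequences, h_v=4 value heads, head_dim=8).
B, h_v, d = 2, 4, 8
state_fp32 = torch.randn(B, h_v, d, d, dtype=torch.float32)  # prefill's final state

# The explicit dtype conversion needed before handing the state to decode:
state_bf16 = state_fp32.to(torch.bfloat16)

assert state_bf16.dtype == torch.bfloat16
assert state_bf16.shape == state_fp32.shape
```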

@vadiklyutiy

What is preventing us from merging this PR?

return output


@supported_compute_capability([90, 100, 110])
Contributor

Does this support B300(sm103)?

Collaborator

Yes this should support sm103

Collaborator

Does 03233c0 work for you?

Contributor

Thanks for replying. Could we please update @supported_compute_capability to include 103 as well?

Collaborator Author

Hi @ZJY0516, I updated the code; it should work now. @yzh119, thanks for the revision. Since sm103 is supported starting from CUDA 12.8, I added is_sm103a_supported instead of is_sm100f_supported, which requires CUDA 12.9.
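For illustration, a hypothetical guard in this spirit (not the FlashInfer implementation) shows what adding 103 to the whitelist buys: the wrapped entry point rejects devices whose SM version is not listed, so B300 (sm103) only passes once 103 is included.

```python
# Hypothetical sketch of a supported_compute_capability whitelist; the
# decorator name mirrors the one in the diff, but the signature and
# dispatch logic here are illustrative only.
def supported_compute_capability(ccs):
    def wrap(fn):
        def inner(device_cc, *args, **kwargs):
            if device_cc not in ccs:
                raise RuntimeError(f"sm{device_cc} is not supported by this kernel")
            return fn(device_cc, *args, **kwargs)
        return inner
    return wrap

@supported_compute_capability([90, 100, 103, 110])
def gdn_prefill(device_cc):
    return f"dispatching GDN prefill for sm{device_cc}"

print(gdn_prefill(103))  # B300 (sm103) now passes the capability check
```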

@yzh119
Collaborator

yzh119 commented Mar 24, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !454 has been created, and the CI pipeline #46876098 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46876098: 1/20 passed

@yzh119
Collaborator

yzh119 commented Mar 24, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !454 has been updated with latest changes, and the CI pipeline #46883505 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46883505: 13/20 passed

@aleozlx added the "v0.6.8 release blocker label for 0.6.7" label on Mar 24, 2026
@dianzhangchen
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

@dianzhangchen is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

return output


@supported_compute_capability([90, 100, 103, 110])
Collaborator

Why do we remove 110 here?

Collaborator Author

I haven’t tested sm110 yet, so support is not confirmed.

@yzh119
Collaborator

yzh119 commented Mar 25, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !454 has been updated with latest changes, and the CI pipeline #46950430 is currently running. I'll report back once the pipeline job completes.

@dianzhangchen
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

@dianzhangchen is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46950430: 9/20 passed

@ZJY0516
Contributor

ZJY0516 commented Mar 25, 2026

I find that this PR does not perform well for small head configurations.

GPU: NVIDIA B300 SXM6 AC
Model reference: Qwen3.5-397B-A17B (h_k=16, h_v=64, d=128, GVA ratio=4)

Config                          h_qk  h_v  FlashInfer  FLA/Triton   Speedup
───────────────────────────────────────────────────────────────────────────
TP4  B=1  S=8192                   4   16      0.748ms      0.395ms     0.53x ✗
TP4  B=1  S=4096                   4   16      0.388ms      0.213ms     0.55x ✗
TP4  B=1  S=2048                   4   16      0.207ms      0.119ms     0.58x ✗
TP2  B=1  S=8192                   8   32      0.749ms      0.566ms     0.76x ✗
TP1  B=1  S=8192                  16   64      0.747ms      0.971ms     1.30x ✓
Symmetric B=4 S=4096              32   32      0.411ms      1.095ms     2.67x ✓
import torch
import torch.nn.functional as F
import numpy as np
from flashinfer.testing import bench_gpu_time
from flashinfer.gdn_prefill import chunk_gated_delta_rule as fi_gdn
from fla.ops.gated_delta_rule.chunk import chunk_gated_delta_rule_fwd as fla_gdn

CONFIGS = [
    # (batch, seq_len, h_qk, h_v, d, label)
    # Qwen3.5-397B (h_k=16, h_v=64, d=128) under different TP
    (1, 8192, 4, 16, 128, "TP4  B=1  S=8192"),
    (1, 4096, 4, 16, 128, "TP4  B=1  S=4096"),
    (1, 2048, 4, 16, 128, "TP4  B=1  S=2048"),
    (1, 8192, 8, 32, 128, "TP2  B=1  S=8192"),
    (1, 8192, 16, 64, 128, "TP1  B=1  S=8192"),
    # Original benchmark config (symmetric heads)
    (4, 4096, 32, 32, 128, "Symmetric B=4 S=4096"),
]

WARMUP = 5
ITERS = 20


def bench_flashinfer(B, T, h_qk, h_v, d):
    device = "cuda"
    dtype = torch.float16
    q = torch.randn((B, T, h_qk, d), dtype=dtype, device=device)
    k = F.normalize(
        torch.randn(B, T, h_qk, d, dtype=torch.float32, device=device), p=2, dim=-1
    ).to(dtype)
    v = torch.randn((B, T, h_v, d), dtype=dtype, device=device)
    g = F.logsigmoid(torch.rand(1, T * B, h_v, dtype=torch.float32, device=device))
    beta = torch.rand(1, T * B, h_v, dtype=torch.float32, device=device).sigmoid()
    h0 = torch.randn((B, h_v, d, d), dtype=torch.float32, device=device)
    state_out = torch.zeros_like(h0)

    fn = lambda: fi_gdn(q, k, v, g, beta, None, h0, True, None, False, None, state_out)
    times = bench_gpu_time(fn, enable_cupti=True, dry_run_iters=WARMUP, repeat_iters=ITERS)
    torch.cuda.empty_cache()
    return np.average(times)


def bench_fla(B, T, h_qk, h_v, d):
    device = "cuda"
    dtype = torch.float16
    # FLA doesn't support GVA (h_qk != h_v), expand q/k to h_v
    h = h_v
    q = torch.randn((B, T, h, d), dtype=dtype, device=device)
    k = F.normalize(
        torch.randn(B, T, h, d, dtype=torch.float32, device=device), p=2, dim=-1
    ).to(dtype)
    v = torch.randn((B, T, h_v, d), dtype=dtype, device=device)
    g = F.logsigmoid(torch.rand(1, T * B, h_v, dtype=torch.float32, device=device))
    beta = torch.rand(1, T * B, h_v, dtype=torch.float32, device=device).sigmoid()
    h0 = torch.randn((B, h_v, d, d), dtype=torch.float32, device=device)

    fn = lambda: fla_gdn(q, k, v, g, beta, None, initial_state=h0, output_final_state=True)
    times = bench_gpu_time(fn, enable_cupti=True, dry_run_iters=WARMUP, repeat_iters=ITERS)
    torch.cuda.empty_cache()
    return np.average(times)


def main():
    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"Model reference: Qwen3.5-397B-A17B (h_k=16, h_v=64, d=128, GVA ratio=4)")
    print()
    header = f"{'Config':<30s}  {'h_qk':>4s} {'h_v':>4s}  {'FlashInfer':>10s}  {'FLA/Triton':>10s}  {'Speedup':>8s}"
    print(header)
    print("─" * len(header))

    for B, T, h_qk, h_v, d, label in CONFIGS:
        fi_ms = bench_flashinfer(B, T, h_qk, h_v, d)
        fla_ms = bench_fla(B, T, h_qk, h_v, d)
        speedup = fla_ms / fi_ms
        marker = "✓" if speedup > 1.0 else "✗"
        print(
            f"{label:<30s}  {h_qk:>4d} {h_v:>4d}  {fi_ms:>9.3f}ms  {fla_ms:>9.3f}ms  {speedup:>7.2f}x {marker}"
        )

if __name__ == "__main__":
    main()

@dianzhangchen
Collaborator Author

> I find that this PR does not perform well for small head configurations. […]

Hi @ZJY0516, this is a known issue, as mentioned in the PR description: "In the current version, we do not parallelize over the sequence-length dimension, so performance is limited when both batch size and head count are small. In such cases, only part of the 148 SMs are used, so the GPU is not fully utilized."
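That limitation can be quantified with a back-of-envelope sketch, assuming one CTA per (batch, value-head) pair as the description above implies; the configurations mirror rows from the benchmark table:

```python
# Rough SM-utilization estimate for the limitation described above: without
# parallelism over the sequence dimension, the grid holds roughly
# batch * h_v CTAs, so on a 148-SM part small batch/head configs leave most
# SMs idle. The "one CTA per (batch, value-head) pair" model is an
# assumption for illustration, not the kernel's exact grid shape.
NUM_SMS = 148

def sm_utilization(batch, h_v, num_sms=NUM_SMS):
    ctas = batch * h_v
    return min(ctas / num_sms, 1.0)

for batch, h_v, label in [(1, 16, "TP4  B=1"), (1, 64, "TP1  B=1"), (4, 32, "Symmetric B=4")]:
    print(f"{label}: {sm_utilization(batch, h_v):.0%} of SMs busy")
```

The TP4 row keeps only about 11% of SMs busy, which is consistent with the slowdown reported in the table.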


Labels

model: qwen3.5 · v0.6.8 release blocker label for 0.6.7
