Add blackwell GDN prefill impl #2742
dianzhangchen wants to merge 8 commits into flashinfer-ai:main from
Conversation
Note: Reviews paused. This branch appears to be under active development, so CodeRabbit has automatically paused this review to avoid overwhelming you with comments on each new commit. This behavior can be configured in the settings; use the review commands or the checkboxes below for quick actions.
📝 Walkthrough

Adds Blackwell-targeted GDN support: a new benchmarking CLI, optional kernel exports, Cutlass SMEM layout helpers, a static tile scheduler with MLIR (de)serialization, SM-specific prefill dispatch, and an extensive GPU test suite for fixed/variable lengths and GVA variants.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI as "Benchmark CLI"
    participant Host as "Python Host"
    participant Dispatcher as "chunk_gated_delta_rule (wrapper)"
    participant Backend as "Blackwell / Hopper backend"
    participant CUTLASS as "Cutlass (SMEM layouts, scheduler)"
    participant GPU as "GPU (kernel launch)"
    CLI->>Host: parse args, prepare tensors / cu_seqlens
    Host->>Dispatcher: call chunk_gated_delta_rule(q,k,v,g,...)
    Dispatcher->>Backend: select backend (Blackwell or Hopper) by SM
    Backend->>CUTLASS: request SMEM layouts & tile schedule
    CUTLASS->>GPU: launch kernel with layouts & schedule
    GPU-->>CUTLASS: kernel completion/results
    CUTLASS-->>Backend: return outputs/final state
    Backend-->>Dispatcher: outputs/timings
    Dispatcher-->>Host: return outputs
    Host-->>CLI: print results / tables
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🧹 Nitpick comments (1)
tests/gdn/test_gdn_prefill_blackwell.py (1)
24-29: Cover the bf16 path too. The tolerance constants already account for bf16, but `testtype` is fixed to fp16, so the bf16 kernel path never runs here. Please parameterize these cases over `torch.float16` and `torch.bfloat16`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 24 - 29, The test hardcodes testtype = torch.float16 so the bf16 branch never executes; change the test to parametrize over both torch.float16 and torch.bfloat16 (e.g. pytest.mark.parametrize or a small loop) and compute oatol, ortol, satol, srtol from that parameter (the existing expressions that check `if testtype is torch.bfloat16` should remain but use the parametrized `testtype`), ensuring the test runs once for each of the two types and exercises the bf16 kernel path (update any test function signature that uses `testtype` accordingly).
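A minimal sketch of the parametrization this prompt describes, assuming pytest and torch are available. The helper name `tolerances_for` and the tolerance values are illustrative placeholders, not the repo's actual constants:

```python
import pytest
import torch


def tolerances_for(testtype):
    # bf16 has fewer mantissa bits than fp16, so loosen the tolerances
    # (illustrative values, not the repo's real oatol/ortol constants)
    if testtype is torch.bfloat16:
        return 3e-2, 3e-2  # oatol, ortol
    return 1e-2, 1e-2


# Parametrizing the dtype makes pytest run each case once per dtype,
# so the bf16 kernel path is exercised instead of only fp16.
@pytest.mark.parametrize("testtype", [torch.float16, torch.bfloat16])
def test_output_tolerances(testtype):
    oatol, ortol = tolerances_for(testtype)
    assert oatol > 0 and ortol > 0
```

The same pattern applies to the state tolerances (`satol`, `srtol`): derive them from the parametrized `testtype` rather than a module-level constant.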
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Around line 93-103: The timed benchmark loop diverges from warmup by setting
output_final_state=True but passing None for the final-state buffer, which
changes workload (allocating or skipping materialization); make the timed path
match the warmup by passing the preallocated state_output when calling
chunk_gated_delta_rule with output_final_state=True (use the same arguments as
the warmup: None, h0, output_final_state, None, False, o, state_output) so each
iteration reuses the preallocated tensor and measures the same work.
- Around line 72-78: The code assumes h0 always exists by calling
torch.zeros_like(h0) which fails when use_initial_state is False; change the
state_output initialization to handle the None case (e.g., set state_output =
torch.zeros_like(h0, dtype=torch.float32) if h0 is not None else None or an
appropriately shaped zero tensor derived from v if later code requires a
tensor), updating the lines that set h0 and state_output (symbols:
use_initial_state, h0, v, state_output, torch.zeros_like) so the "no initial
state" path does not raise.
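A small sketch of the None-safe initialization this comment asks for, assuming torch is available; the helper name `make_state_buffers` and the tensor shape are illustrative, not the benchmark's actual code:

```python
import torch


def make_state_buffers(use_initial_state, shape=(1, 2, 8, 8)):
    # h0 is only allocated when an initial state is requested
    h0 = torch.zeros(shape, dtype=torch.float32) if use_initial_state else None
    # guard zeros_like so the "no initial state" path does not raise
    state_output = torch.zeros_like(h0) if h0 is not None else None
    return h0, state_output


h0, state_output = make_state_buffers(False)
assert h0 is None and state_output is None
```

If later code requires a tensor even without an initial state, the `None` branch would instead build an appropriately shaped zero tensor derived from `v`, as the prompt notes.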
In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_helpers.py`:
- Around line 65-68: The locals a_smem_shape_mn_k (and the similar unused tuple
at lines ~122-125) are dead and trigger unused-variable warnings; remove these
unused assignments from gdn_helpers.py or replace them with explicit sanity
checks (e.g., assert statements validating a_smem_shape dimensions) if they were
intended as checks—locate the assignments to a_smem_shape_mn_k and the
corresponding a_smem_shape_k_m and either delete the lines or convert them into
assertions referencing a_smem_shape and loc/ip to preserve the intended
validation.
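A sketch of the "convert dead assignment into a sanity check" option, if the tuple was intended as validation; the helper name, shape values, and expected rank here are illustrative assumptions, not the actual layout logic in gdn_helpers.py:

```python
def check_a_smem_shape(a_smem_shape):
    # Preserve the intent of the unused a_smem_shape_mn_k tuple as an
    # explicit assertion instead of a dead local assignment.
    assert len(a_smem_shape) == 3, (
        f"expected a rank-3 (M, N, K) SMEM shape, got {a_smem_shape}"
    )
    return a_smem_shape


check_a_smem_shape((128, 64, 8))
```

Deleting the dead lines outright is equally valid; the assertion form is only worth keeping if the shape invariant could realistically be violated.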
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 17-23: The test file eagerly imports chunk_gated_delta_rule which
breaks collection on machines without Blackwell/SM100; update
tests/gdn/test_gdn_prefill_blackwell.py to guard collection by using the repo
helpers (get_compute_capability()/is_sm100a_supported()) or pytest.importorskip
before importing chunk_gated_delta_rule, and import the kernel via the optional
public API path (or via
pytest.importorskip("flashinfer.gdn_kernels.blackwell_prefill") ) so the module
is skipped cleanly when SM100/Blackwell DSL is unavailable; ensure references to
chunk_gated_delta_rule remain after the guarded import.
---
Nitpick comments:
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 24-29: The test hardcodes testtype = torch.float16 so the bf16
branch never executes; change the test to parametrize over both torch.float16
and torch.bfloat16 (e.g. pytest.mark.parametrize or a small loop) and compute
oatol, ortol, satol, srtol from that parameter (the existing expressions that
check `if testtype is torch.bfloat16` should remain but use the parametrized
`testtype`), ensuring the test runs once for each of the two types and exercises
the bf16 kernel path (update any test function signature that uses `testtype`
accordingly).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f3d5740f-2232-4a71-b337-adf4709655f1
📒 Files selected for processing (6)
- benchmarks/bench_blackwell_gdn_prefill.py
- flashinfer/gdn_kernels/__init__.py
- flashinfer/gdn_kernels/blackwell_prefill/gdn.py
- flashinfer/gdn_kernels/blackwell_prefill/gdn_helpers.py
- flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py
- tests/gdn/test_gdn_prefill_blackwell.py
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates a new, highly optimized Gated Delta Network (GDN) linear attention kernel tailored for NVIDIA Blackwell GPUs. The implementation, built with CuTe-DSL, aims to improve the efficiency of linear attention operations by supporting diverse sequence lengths and grouped value attention. It includes dedicated benchmarking and testing to verify performance gains and correctness.
Code Review
This pull request introduces a new Gated Delta Network (GDN) prefill kernel for NVIDIA Blackwell GPUs, implemented using CuTe-DSL. The changes are comprehensive, including the kernel implementation, helper utilities, a tile scheduler, new tests, and benchmarks. The code is well-structured and adds significant new capabilities. My review identifies a few minor areas for improvement, mainly related to code cleanup in helper files, consistency in the benchmark script, and a redundant operation in the test's reference implementation. Overall, this is a great contribution.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/gdn/test_gdn_prefill_blackwell.py (1)
17-23: ⚠️ Potential issue | 🟠 Major
Guard test collection on SM100 availability.

This import will fail during test collection on non-SM100 machines or builds without the Blackwell DSL. Use `pytest.importorskip` or the repo's SM100 helpers to skip cleanly on unsupported environments. As per coding guidelines: "Use flashinfer.utils functions (`is_sm100a_supported()`) to skip tests on unsupported GPU architectures".

🛠️ Suggested fix
```diff
 import pytest
 import torch
 import torch.nn.functional as F
-from flashinfer.gdn_kernels.blackwell_prefill.gdn import chunk_gated_delta_rule
+from flashinfer.utils import is_sm100a_supported
+
+# Skip entire module if SM100 is not supported
+pytestmark = pytest.mark.skipif(
+    not is_sm100a_supported(),
+    reason="Blackwell (SM100) not supported on this device",
+)
+
+# Import after guard to prevent collection failures
+chunk_gated_delta_rule = pytest.importorskip(
+    "flashinfer.gdn_kernels.blackwell_prefill.gdn"
+).chunk_gated_delta_rule
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 17 - 23, The test imports Blackwell-only code unguarded which fails on non-SM100 builds; guard collection by skipping when SM100/Blackwell DSL is unavailable: at the top of tests/gdn/test_gdn_prefill_blackwell.py use pytest.importorskip for the Blackwell module or call flashinfer.utils.is_sm100a_supported() and pytest.skip if False before importing chunk_gated_delta_rule so the import of chunk_gated_delta_rule and definition of testtype only occur on supported hardware.
🧹 Nitpick comments (6)
tests/gdn/test_gdn_prefill_blackwell.py (2)
176-192: Prefix unused unpacked variable with underscore. The `o` variable is unpacked but never used in `test_fixlen_state`. Use `_` or `_o` to indicate intentional discard and silence the linter warning.

♻️ Suggested fix

```diff
-    o, state_output = chunk_gated_delta_rule(
+    _o, state_output = chunk_gated_delta_rule(
```

This also applies to similar patterns in `test_varlen_state` (line 313), `test_fixlen_gva_state` (line 444), and `test_varlen_gva_state` (line 587).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 176 - 192, The unpacked output variable `o` from the call to chunk_gated_delta_rule is unused and should be prefixed with an underscore to silence the linter; update the unpackings in test_fixlen_state, test_varlen_state, test_fixlen_gva_state, and test_varlen_gva_state so the first element is assigned to `_` or `_o` (e.g., replace `o, state_output = chunk_gated_delta_rule(...)` with `_, state_output = chunk_gated_delta_rule(...)`) while leaving the function call and `state_output` handling unchanged.
31-40: Use explicit `Optional` type hint. PEP 484 prohibits implicit `Optional`. The `scale` and `initial_state` parameters should use explicit `Optional[T]` or `T | None` syntax.

♻️ Suggested fix

```diff
+from typing import Optional
+
 def recurrent_gated_delta_rule_ref(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     g: torch.Tensor,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
 ):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 31 - 40, The function recurrent_gated_delta_rule_ref has parameters scale and initial_state typed as defaulting to None; change their annotations to explicit Optional types (e.g., Optional[float] and Optional[torch.Tensor] or float | None and torch.Tensor | None) and import typing.Optional if using that form; update the function signature accordingly (referencing recurrent_gated_delta_rule_ref, scale, and initial_state) so the hints comply with PEP 484.

benchmarks/bench_blackwell_gdn_prefill.py (3)

36-39: Consider importing from public API. The benchmark imports directly from the internal module path. For consistency and to verify the public API works correctly, consider importing via the public interface.
♻️ Suggested change

```diff
-from flashinfer.gdn_kernels.blackwell_prefill.gdn import chunk_gated_delta_rule
+from flashinfer.gdn_kernels import chunk_gated_delta_rule
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 36 - 39, The benchmark imports chunk_gated_delta_rule and chunk_gated_delta_rule_fwd directly from internal modules; update the imports to use the package public API instead (import the same symbols via their public export points) so the benchmark verifies the public interface and avoids internal paths—replace the direct internal imports of chunk_gated_delta_rule (from flashinfer.gdn_kernels.blackwell_prefill.gdn) and chunk_gated_delta_rule_fwd (from fla.ops.gated_delta_rule.chunk) with their corresponding public API imports.
153-162: Prefer `def` over lambda assignment. Static analysis (Ruff E731) flags the lambda assignment. This is a minor style issue, but using `def` is more Pythonic and allows for better debugging.

♻️ Suggested fix

```diff
-    fn_fla = lambda: fla_base(
-        q,
-        k,
-        v,
-        g,
-        beta,
-        None,
-        initial_state=h0,
-        output_final_state=output_final_state,
-    )
+    def fn_fla():
+        return fla_base(
+            q,
+            k,
+            v,
+            g,
+            beta,
+            None,
+            initial_state=h0,
+            output_final_state=output_final_state,
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 153 - 162, Replace the lambda assignment to fn_fla with a proper named function definition: create a def fn_fla(...): that calls fla_base(q, k, v, g, beta, None, initial_state=h0, output_final_state=output_final_state) and returns its result, preserving the same captured variables (q, k, v, g, beta, h0, output_final_state); this removes the lambda (Ruff E731) and makes debugging/tracing of fn_fla clearer while keeping behavior identical.
122-148: Consider using `flashinfer.testing.bench_gpu_time()` for timing. The benchmark uses manual CUDA events for timing. As per coding guidelines: "Use `flashinfer.testing.bench_gpu_time()` for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events."

♻️ Suggested approach

```python
from flashinfer.testing import bench_gpu_time

# Replace manual event timing with:
avg_latency_ms = bench_gpu_time(
    lambda: chunk_gated_delta_rule(
        q, k, v, g, beta, None, h0,
        output_final_state, None, False, o, None,
    ),
    warmup_iters=warmup_iters,
    benchmark_iters=benchmark_iters,
)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 122 - 148, The timing block uses manual torch.cuda.Event timing; replace it by calling flashinfer.testing.bench_gpu_time and pass a lambda that calls chunk_gated_delta_rule(q, k, v, g, beta, None, h0, output_final_state, None, False, o, None), remove start_event/end_event, torch.cuda.synchronize, and elapsed_time/avg computation, and assign avg_latency_ms to the bench_gpu_time return value; also add the import for bench_gpu_time and forward warmup_iters and benchmark_iters to the bench_gpu_time call.

flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py (1)
140-150: Magic number in assertion is fragile. The hard-coded `len(values) == 10` is brittle. If the MLIR value structure changes (e.g., adding fields to params or changing coordinate dimensions), this assertion will fail without a clear indication of what changed.

♻️ Consider deriving expected length

```diff
 def __new_from_mlir_values__(self, values):
-    assert len(values) == 10
+    # Expected: 3 (params) + 1 (current_work_linear_idx) + 3 (blk_coord) + 3 (grid_shape)
+    expected_len = 10
+    assert len(values) == expected_len, f"Expected {expected_len} values, got {len(values)}"
     new_params = cutlass.new_from_mlir_values(self._params, values[0:3])
```

Or better, compute the expected length from the extraction method to keep them in sync.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py` around lines 140 - 150, The assertion uses a brittle magic number (len(values) == 10); replace it by computing the expected number of MLIR values from the actual components so the check stays in sync with structure changes: compute expected_count as the sum of MLIR-value counts for self._params, self._current_work_linear_idx, self._blk_coord and self._grid_shape (e.g. expected_count = cutlass.mlir_value_count(self._params) + cutlass.mlir_value_count(self._current_work_linear_idx) + cutlass.mlir_value_count(self._blk_coord) + cutlass.mlir_value_count(self._grid_shape) or, if that helper doesn’t exist, add a small helper that inspects each object’s stored mlir values to return their lengths), then assert len(values) == expected_count before slicing and calling cutlass.new_from_mlir_values in __new_from_mlir_values__ when constructing GdnStaticTileScheduler.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py`:
- Around line 30-37: In __new_from_mlir_values__ (in
GdnStaticTileSchedulerParams reconstruction) the created object omits the
original ip context; update the final constructor call
GdnStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) to also pass
ip=self._ip so the reconstructed object preserves both loc and ip; locate
references to self.is_persistent, self.problem_shape_mbh, self._values_pos and
cutlass.new_from_mlir_values when making the change.
---
Duplicate comments:
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 17-23: The test imports Blackwell-only code unguarded which fails
on non-SM100 builds; guard collection by skipping when SM100/Blackwell DSL is
unavailable: at the top of tests/gdn/test_gdn_prefill_blackwell.py use
pytest.importorskip for the Blackwell module or call
flashinfer.utils.is_sm100a_supported() and pytest.skip if False before importing
chunk_gated_delta_rule so the import of chunk_gated_delta_rule and definition of
testtype only occur on supported hardware.
---
Nitpick comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Around line 36-39: The benchmark imports chunk_gated_delta_rule and
chunk_gated_delta_rule_fwd directly from internal modules; update the imports to
use the package public API instead (import the same symbols via their public
export points) so the benchmark verifies the public interface and avoids
internal paths—replace the direct internal imports of chunk_gated_delta_rule
(from flashinfer.gdn_kernels.blackwell_prefill.gdn) and
chunk_gated_delta_rule_fwd (from fla.ops.gated_delta_rule.chunk) with their
corresponding public API imports.
- Around line 153-162: Replace the lambda assignment to fn_fla with a proper
named function definition: create a def fn_fla(...): that calls fla_base(q, k,
v, g, beta, None, initial_state=h0, output_final_state=output_final_state) and
returns its result, preserving the same captured variables (q, k, v, g, beta,
h0, output_final_state); this removes the lambda (Ruff E731) and makes
debugging/tracing of fn_fla clearer while keeping behavior identical.
- Around line 122-148: The timing block uses manual torch.cuda.Event timing;
replace it by calling flashinfer.testing.bench_gpu_time and pass a lambda that
calls chunk_gated_delta_rule(q, k, v, g, beta, None, h0, output_final_state,
None, False, o, None), remove start_event/end_event, torch.cuda.synchronize, and
elapsed_time/avg computation, and assign avg_latency_ms to the bench_gpu_time
return value; also add the import for bench_gpu_time and forward warmup_iters
and benchmark_iters to the bench_gpu_time call.
In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py`:
- Around line 140-150: The assertion uses a brittle magic number (len(values) ==
10); replace it by computing the expected number of MLIR values from the actual
components so the check stays in sync with structure changes: compute
expected_count as the sum of MLIR-value counts for self._params,
self._current_work_linear_idx, self._blk_coord and self._grid_shape (e.g.
expected_count = cutlass.mlir_value_count(self._params) +
cutlass.mlir_value_count(self._current_work_linear_idx) +
cutlass.mlir_value_count(self._blk_coord) +
cutlass.mlir_value_count(self._grid_shape) or, if that helper doesn’t exist, add
a small helper that inspects each object’s stored mlir values to return their
lengths), then assert len(values) == expected_count before slicing and calling
cutlass.new_from_mlir_values in __new_from_mlir_values__ when constructing
GdnStaticTileScheduler.
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 176-192: The unpacked output variable `o` from the call to
chunk_gated_delta_rule is unused and should be prefixed with an underscore to
silence the linter; update the unpackings in test_fixlen_state,
test_varlen_state, test_fixlen_gva_state, and test_varlen_gva_state so the first
element is assigned to `_` or `_o` (e.g., replace `o, state_output =
chunk_gated_delta_rule(...)` with `_, state_output =
chunk_gated_delta_rule(...)`) while leaving the function call and `state_output`
handling unchanged.
- Around line 31-40: The function recurrent_gated_delta_rule_ref has parameters
scale and initial_state typed as defaulting to None; change their annotations to
explicit Optional types (e.g., Optional[float] and Optional[torch.Tensor] or
float | None and torch.Tensor | None) and import typing.Optional if using that
form; update the function signature accordingly (referencing
recurrent_gated_delta_rule_ref, scale, and initial_state) so the hints comply
with PEP 484.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 28ac11e0-8733-444d-a346-5bd4528d3929
📒 Files selected for processing (6)
- benchmarks/bench_blackwell_gdn_prefill.py
- flashinfer/gdn_kernels/__init__.py
- flashinfer/gdn_kernels/blackwell_prefill/gdn.py
- flashinfer/gdn_kernels/blackwell_prefill/gdn_helpers.py
- flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py
- tests/gdn/test_gdn_prefill_blackwell.py
Actionable comments posted: 1
🧹 Nitpick comments (3)
tests/gdn/test_gdn_prefill_blackwell.py (2)
50-50: Use explicit `Optional[float]` type hint. PEP 484 prohibits implicit `Optional`. The parameter default `None` requires explicit optional typing.

♻️ Suggested fix

```diff
+from typing import Optional
+
 def recurrent_gated_delta_rule_ref(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     g: torch.Tensor,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
 ):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_gdn_prefill_blackwell.py` at line 50, The parameter signature uses an implicit optional type "scale: float = None"; change it to an explicit Optional type by updating the parameter to "scale: Optional[float]" and add "from typing import Optional" to the imports in tests/gdn/test_gdn_prefill_blackwell.py so the annotation is PEP 484-compliant; update any other occurrences of the same signature in that file to match.
208-221: Prefix unused `o` with underscore. In state-specific tests, the output tensor `o` is unpacked but not used. Prefix with `_` to indicate it's intentionally ignored.

♻️ Suggested fix (apply to lines 208, 348, 485, 634)

```diff
-    o, state_output = chunk_gated_delta_rule(
+    _o, state_output = chunk_gated_delta_rule(
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 208 - 221, The test unpacks the return value `o` from chunk_gated_delta_rule but never uses it; rename it to `_o` (or simply `_`) in each unpacking to signal the unused tensor and satisfy lint/tests; update occurrences where chunk_gated_delta_rule is called (e.g., the unpacking at the call that currently assigns `o, state_output = chunk_gated_delta_rule(...)`) to `_o, state_output = chunk_gated_delta_rule(...)` (apply the same change to the other similar calls mentioned).

benchmarks/bench_blackwell_gdn_prefill.py (1)
153-162: Prefer `def` over lambda assignment. Ruff E731 flags this lambda assignment. Using a named function improves readability and debuggability.

♻️ Suggested refactor

```diff
-    fn_fla = lambda: fla_base(
-        q,
-        k,
-        v,
-        g,
-        beta,
-        None,
-        initial_state=h0,
-        output_final_state=output_final_state,
-    )
+    def fn_fla():
+        return fla_base(
+            q,
+            k,
+            v,
+            g,
+            beta,
+            None,
+            initial_state=h0,
+            output_final_state=output_final_state,
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 153 - 162, The lambda assigned to fn_fla should be replaced with a named function to satisfy Ruff E731 and improve readability: define a function (e.g., def fn_fla()) that accepts no arguments and calls fla_base(q, k, v, g, beta, None, initial_state=h0, output_final_state=output_final_state) and then use that function in place of the lambda; keep the same references to fla_base, h0, and output_final_state so behavior is unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Line 35: Fix the typo in the header comment: change the comment string
"Blacwell GDN prefill" to "Blackwell GDN prefill" so the module description
reads correctly (look for the comment line containing "Blacwell GDN prefill" in
benchmarks/bench_blackwell_gdn_prefill.py).
---
Nitpick comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Around line 153-162: The lambda assigned to fn_fla should be replaced with a
named function to satisfy Ruff E731 and improve readability: define a function
(e.g., def fn_fla()) that accepts no arguments and calls fla_base(q, k, v, g,
beta, None, initial_state=h0, output_final_state=output_final_state) and then
use that function in place of the lambda; keep the same references to fla_base,
h0, and output_final_state so behavior is unchanged.
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Line 50: The parameter signature uses an implicit optional type "scale: float
= None"; change it to an explicit Optional type by updating the parameter to
"scale: Optional[float]" and add "from typing import Optional" to the imports in
tests/gdn/test_gdn_prefill_blackwell.py so the annotation is PEP 484-compliant;
update any other occurrences of the same signature in that file to match.
- Around line 208-221: The test unpacks the return value `o` from
chunk_gated_delta_rule but never uses it; rename it to `_o` (or simply `_`) in
each unpacking to signal the unused tensor and satisfy lint/tests — update
occurrences where chunk_gated_delta_rule is called (e.g., the unpacking at the
call that currently assigns `o, state_output = chunk_gated_delta_rule(...)`) to
`_o, state_output = chunk_gated_delta_rule(...)` (apply the same change to the
other similar calls mentioned).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b0fdf0a2-3caf-4cae-8abe-675099594482
📒 Files selected for processing (2)
- benchmarks/bench_blackwell_gdn_prefill.py
- tests/gdn/test_gdn_prefill_blackwell.py
Actionable comments posted: 1
🧹 Nitpick comments (4)
flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py (1)
130-133: Consider explicit `Int32` cast for type consistency. The similar implementation in `grouped_gemm_masked_blackwell.py` explicitly casts `advance_count` to `Int32` before arithmetic:

```python
self._current_work_linear_idx += Int32(advance_count) * Int32(self.num_persistent_clusters)
```

This ensures type consistency in MLIR code generation. The current implementation may work due to implicit conversion, but explicit casting would be more robust.

♻️ Suggested improvement

```diff
 def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None):
     if self._is_persistent:
-        self._current_work_linear_idx += advance_count * self.num_persistent_sm
+        self._current_work_linear_idx += Int32(advance_count) * self.num_persistent_sm
     self._is_first_block = False
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py` around lines 130 - 133, In advance_to_next_work ensure MLIR integer types are explicit by casting advance_count and num_persistent_sm to Int32 before the multiplication and addition to _current_work_linear_idx; specifically, in the method advance_to_next_work (check _is_persistent branch) replace the implicit arithmetic with an explicit Int32(advance_count) * Int32(self.num_persistent_sm) and add that result to _current_work_linear_idx so type consistency matches the grouped_gemm_masked_blackwell pattern.

benchmarks/bench_blackwell_gdn_prefill.py (3)
476-478: Consider adding GVA constraint validation. The kernel requires `num_v_heads % num_qk_heads == 0` for grouped value attention (GVA). Adding this validation at the CLI level would provide clearer error messages than the kernel warning.

♻️ Suggested addition after head_dim check

```diff
     if args.head_dim != 128:
         print(f"Error: head_dim must be 128, got {args.head_dim}")
         sys.exit(1)
+
+    if args.num_v_heads % args.num_qk_heads != 0:
+        print(f"Error: num_v_heads ({args.num_v_heads}) must be divisible by num_qk_heads ({args.num_qk_heads}) for GVA")
+        sys.exit(1)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 476 - 478, Add a CLI-level validation that enforces the grouped value attention constraint by checking args.num_v_heads % args.num_qk_heads == 0 (after the existing head_dim check); if the condition fails, print a clear error like "Error: num_v_heads must be divisible by num_qk_heads, got {args.num_v_heads} and {args.num_qk_heads}" and call sys.exit(1) so the program exits with a helpful message before hitting the kernel.
153-162: Prefer `def` over assigned lambda. Per Python style guidelines (PEP 8), named functions should use `def` rather than assigning a lambda to a variable.

♻️ Suggested fix

```diff
-    fn_fla = lambda: fla_base(
-        q,
-        k,
-        v,
-        g,
-        beta,
-        None,
-        initial_state=h0,
-        output_final_state=output_final_state,
-    )
+    def fn_fla():
+        return fla_base(
+            q,
+            k,
+            v,
+            g,
+            beta,
+            None,
+            initial_state=h0,
+            output_final_state=output_final_state,
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 153 - 162, Replace the assigned lambda fn_fla with a proper named function using def to follow PEP 8: create def fn_fla(): that calls fla_base(q, k, v, g, beta, None, initial_state=h0, output_final_state=output_final_state) and returns its result; update any references to fn_fla unchanged. Ensure parameter capture uses the same outer-scope variables (q, k, v, g, beta, h0, output_final_state) as in the original lambda.
362-364: Consider narrowing exception handling.

Catching broad `Exception` hides unrelated failures (and a bare `except:` would even swallow `KeyboardInterrupt` and `SystemExit`), preventing clean script termination. For benchmark robustness while allowing user interrupts, consider catching `RuntimeError` or a more specific exception type.

♻️ Suggested fix (also apply to line 384)

```diff
-        except Exception as e:
+        except (RuntimeError, ValueError, torch.cuda.CudaError) as e:
             print(f"  FAILED: {e}")
             torch.cuda.empty_cache()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 362 - 364, The current except block uses a broad "except Exception as e:" which also intercepts KeyboardInterrupt/SystemExit; replace it with a narrower exception handler (e.g., "except RuntimeError as e:") and keep the same cleanup (torch.cuda.empty_cache()) and logging, and apply the same change to the similar handler later in the file (the other except block that prints FAILED and calls torch.cuda.empty_cache()). If you must catch multiple runtime errors, list them explicitly or, if you need to preserve re-raising interrupts, add a guard to re-raise KeyboardInterrupt and SystemExit before handling other exceptions.
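The re-raise guard mentioned in the prompt can be sketched generically. This is a dependency-free illustration, not the benchmark's actual code; `run_case` and the sample failure are hypothetical names:

```python
def run_case(fn):
    """Run one benchmark case, reporting failures but letting
    user interrupts and normal exits propagate."""
    try:
        return fn()
    except (KeyboardInterrupt, SystemExit):
        raise  # never swallow clean-termination signals
    except RuntimeError as e:  # CUDA errors (e.g. OOM) surface as RuntimeError in torch
        print(f"  FAILED: {e}")
        return None


def failing_case():
    # Stand-in for a kernel launch that runs out of memory.
    raise RuntimeError("CUDA out of memory")


print(run_case(lambda: 42))    # 42
print(run_case(failing_case))  # prints "  FAILED: ..." and returns None
```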
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py`:
- Around line 142-152: The __new_from_mlir_values__ method reconstructs a
GdnStaticTileScheduler but omits passing the loc and ip context parameters;
update the return to pass the original object's loc and ip (e.g., self.loc and
self.ip) into the GdnStaticTileScheduler constructor so the reconstructed
scheduler preserves the source debug/location context; modify the final return
in __new_from_mlir_values__ of GdnStaticTileScheduler to include loc and ip
along with new_params, new_current_work_linear_idx, new_blk_coord, and
new_grid_shape.
---
Nitpick comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Around line 476-478: Add a CLI-level validation that enforces the grouped
value attention constraint by checking args.num_v_heads % args.num_qk_heads == 0
(after the existing head_dim check); if the condition fails, print a clear error
like "Error: num_v_heads must be divisible by num_qk_heads, got
{args.num_v_heads} and {args.num_qk_heads}" and call sys.exit(1) so the program
exits with a helpful message before hitting the kernel.
- Around line 153-162: Replace the assigned lambda fn_fla with a proper named
function using def to follow PEP 8: create def fn_fla(): that calls fla_base(q,
k, v, g, beta, None, initial_state=h0, output_final_state=output_final_state)
and returns its result; update any references to fn_fla unchanged. Ensure
parameter capture uses the same outer-scope variables (q, k, v, g, beta, h0,
output_final_state) as in the original lambda.
- Around line 362-364: The current except block uses a broad "except Exception
as e:" which also intercepts KeyboardInterrupt/SystemExit; replace it with a
narrower exception handler (e.g., "except RuntimeError as e:") and keep the same
cleanup (torch.cuda.empty_cache()) and logging, and apply the same change to the
similar handler later in the file (the other except block that prints FAILED and
calls torch.cuda.empty_cache()). If you must catch multiple runtime errors, list
them explicitly or, if you need to preserve re-raising interrupts, add a guard
to re-raise KeyboardInterrupt and SystemExit before handling other exceptions.
In `@flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py`:
- Around line 130-133: In advance_to_next_work ensure MLIR integer types are
explicit by casting advance_count and num_persistent_sm to Int32 before the
multiplication and addition to _current_work_linear_idx; specifically, in the
method advance_to_next_work (check _is_persistent branch) replace the implicit
arithmetic with an explicit Int32(advance_count) * Int32(self.num_persistent_sm)
and add that result to _current_work_linear_idx so type consistency matches the
grouped_gemm_masked_blackwell pattern.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 11adc012-ba34-4bda-89ae-038b83eab541
📒 Files selected for processing (2)
- benchmarks/bench_blackwell_gdn_prefill.py
- flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py
flashinfer/gdn_kernels/__init__.py
Outdated
```python
try:
    from .blackwell_prefill.gdn import chunk_gated_delta_rule
except ImportError:
    chunk_gated_delta_rule = None  # type: ignore
```
Can you call this function on SM100 in https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_prefill.py#L86? The chunk_gated_delta_rule function in this file should be the only API for the user to call the prefill gdn kernel. It only supports SM90 on the main branch.
@hlu1 sure, thanks for pointing this out
@hlu1 done, changed the API to gdn_prefill
```python
state_output = torch.zeros_like(h0, dtype=torch.float)

o, state_output = chunk_gated_delta_rule(
    q.clone(),
```
Why do we need these clones here?
@yzh119 thanks for pointing this out. For unit tests, we do not trust the code under test, so we cannot guarantee these tensors are not incorrectly modified by the function being tested; that is why I used clones.
I have now removed these clones.
```python
    None,
    state_output,
)
torch.cuda.synchronize()
```
Thanks for pointing this out. Done, removed the synchronization.
```python
ref_ht = torch.cat((ref_ht, ref_ht_i), dim=0).contiguous()

ref_ht = torch.transpose(ref_ht, -1, -2).contiguous()
torch.cuda.synchronize()
```
ditto and done
```python
    None,
    state_output,
)
torch.cuda.synchronize()
```
ditto and done
```python
state_output = torch.zeros_like(h0, dtype=torch.float)

o, state_output = chunk_gated_delta_rule(
    q.clone(),
```
ditto and done
```python
    None,
    state_output,
)
torch.cuda.synchronize()
```
ditto and done
```python
    lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g]
)
B, H, T, K, V = *k.shape, v.shape[-1]
o = torch.zeros(B, H, T, V).to(v)
```
it's much faster to directly create a GPU tensor than moving tensors from CPU to GPU, especially in unittests.
```diff
-o = torch.zeros(B, H, T, V).to(v)
+o = torch.zeros(B, H, T, V, device=v.device)
```
Thanks for pointing this out. Done, q/k/v are now GPU tensors.
```python
):
    """Reference implementation of gated delta rule (recurrent version)."""
    q, k, v, beta, g = map(
        lambda x: x.transpose(1, 2).contiguous().to(torch.float32), [q, k, v, beta, g]
```
It will be much faster if we can pass GPU tensors directly.
Thanks for pointing this out. Done, q/k/v are now GPU tensors.
```python
device = "cuda"

q = torch.randn((B, T, H, D), dtype=dtype)
```
create these tensors on GPU
ditto and done
```python
# Benchmark
torch.cuda.reset_peak_memory_stats()
start_event = torch.cuda.Event(enable_timing=True)
```
Please consider using https://docs.flashinfer.ai/generated/flashinfer.testing.bench_gpu_time.html#flashinfer.testing.bench_gpu_time for benchmarking (and enable cupti), see https://github.com/flashinfer-ai/flashinfer/blob/main/.claude/skills/benchmark-kernel/SKILL.md for details.
@yzh119 thanks for pointing this out. Done, now using the bench_gpu_time function.
- Change func api position
- Bench use bench_gpu_time func
- Update default test config
Actionable comments posted: 4
♻️ Duplicate comments (1)
tests/gdn/test_gdn_prefill_blackwell.py (1)
26-33: ⚠️ Potential issue | 🟠 Major

Derive the skip flag from the actual Blackwell backend, not the wrapper import.

Importing `flashinfer.gdn_prefill` is not enough here: `flashinfer.gdn_kernels.blackwell_prefill` can export `chunk_gated_delta_rule = None` on `ImportError`, so `GDN_BLACKWELL_PREFILL_CUTEDSL` can be `True` even when the Blackwell DSL backend is missing. Please key this off the actual Blackwell symbol (or `pytest.importorskip`) so SM100/SM110 builders skip cleanly instead of failing at call time.

Also applies to: 82-90
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 26 - 33, The test's skip flag is derived from importing flashinfer.gdn_prefill instead of the actual Blackwell backend symbol; update the check to verify the real backend export (e.g. flashinfer.gdn_kernels.blackwell_prefill.chunk_gated_delta_rule) or use pytest.importorskip("flashinfer.gdn_kernels.blackwell_prefill") to skip when the Blackwell DSL is missing, and set GDN_BLACKWELL_PREFILL_CUTEDSL based on the presence/non-None value of chunk_gated_delta_rule rather than the wrapper import success.
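The fix the prompt describes boils down to one guard pattern. Here is a dependency-free sketch (the helper name `backend_available` is illustrative; the module and symbol names in the commented example come from this PR):

```python
import importlib
import importlib.util


def backend_available(module_name: str, symbol: str) -> bool:
    """True only if the module imports cleanly AND exports a non-None symbol,
    mirroring how blackwell_prefill may export chunk_gated_delta_rule = None."""
    if importlib.util.find_spec(module_name) is None:
        return False
    module = importlib.import_module(module_name)
    return getattr(module, symbol, None) is not None


# In the test module this would become:
# GDN_BLACKWELL_PREFILL_CUTEDSL = backend_available(
#     "flashinfer.gdn_kernels.blackwell_prefill", "chunk_gated_delta_rule"
# )
print(backend_available("json", "loads"))        # True: present and non-None
print(backend_available("no_such_module", "x"))  # False: import would fail
```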
🧹 Nitpick comments (1)
tests/gdn/test_gdn_prefill_blackwell.py (1)
36-41: Please exercise the BF16 path somewhere in this suite.

The tests hardcode `torch.float16`, so none of the new cases cover the other advertised dtype. A small `torch.bfloat16` parameterization or smoke subset would catch dtype-specific regressions in the Blackwell kernel.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_gdn_prefill_blackwell.py` around lines 36 - 41, The test currently hardcodes testtype = torch.float16 so BF16 paths are never exercised; change the test to parametrize testtype over (torch.float16, torch.bfloat16) (or add a small smoke test that sets testtype = torch.bfloat16) and adjust the tolerance logic that depends on testtype (oatol, ortol, satol, srtol) to pick the correct values for each dtype; look for the variable testtype and the tolerance assignments in the test_gdn_prefill_blackwell tests and update them to run/assert for both dtypes (use pytest.mark.parametrize on the surrounding test function or add an additional test function that mirrors the current assertions but uses torch.bfloat16).
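A minimal sketch of that parameterization follows. The tolerance values are illustrative placeholders, not the suite's real bounds, and dtype keys are strings so the sketch stays dependency-free; the real test would use `torch.float16` / `torch.bfloat16` via `pytest.mark.parametrize`:

```python
# Hypothetical per-dtype tolerances; BF16 keeps fewer mantissa bits than FP16,
# so its bounds are looser.
TOLERANCES = {
    "float16":  {"oatol": 1e-2, "ortol": 1e-2, "satol": 1e-3, "srtol": 1e-3},
    "bfloat16": {"oatol": 4e-2, "ortol": 4e-2, "satol": 4e-3, "srtol": 4e-3},
}


def pick_tolerances(testtype: str) -> dict:
    """Select assertion tolerances for the dtype under test."""
    return TOLERANCES[testtype]


# The pytest version would look like:
# @pytest.mark.parametrize("testtype", [torch.float16, torch.bfloat16])
# def test_gdn_prefill(testtype): ...
for testtype in ("float16", "bfloat16"):
    tols = pick_tolerances(testtype)
    print(testtype, tols["oatol"], tols["satol"])
```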
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Line 38: The benchmark imports the public dispatcher (chunk_gated_delta_rule)
which will trigger Hopper (SM90) kernels on unsupported hardware; add a
fail-fast GPU check at the top of the module before importing or invoking
chunk_gated_delta_rule (and duplicate checks for the other similar import/usage
blocks around the 428-437 region): query the current CUDA device properties
(e.g., via torch.cuda.get_device_properties(device) or the CUDA runtime API),
extract the SM/compute capability, and if it is not the target SM100 or SM110
family raise a clear RuntimeError (or exit) explaining the benchmark requires
SM100/SM110 hardware so the Hopper path is not executed.
In `@flashinfer/gdn_prefill.py`:
- Around line 203-209: The public docstring promises that g=None and beta=None
default to all-ones, but current code defers this to backend dispatch and raises
NotImplementedError on SM100/SM110; eagerly materialize defaults before dispatch
by checking if g or beta is None and, using total_seq_len and num_sab_heads,
create float32 torch.ones tensors (shape [total_seq_len, num_sab_heads]) so the
backend always receives concrete tensors; apply the same change to the other
occurrence mentioned (around the 259-262 region) so behavior matches the docs
across backends.
- Around line 176-177: The API wrapper chunk_gated_delta_rule currently handles
SM90/SM100/SM110 dispatch but lacks the standard backend-requirement contract;
add the `@backend_requirement` decorator to chunk_gated_delta_rule and to the
other compute-gated wrappers in the same region (the functions around the
240-280 block), and attach the required support helpers to each function:
implement and assign is_compute_capability_supported(cc) to mirror the dispatch
logic (return True for SM90/SM100/SM110 where appropriate) and implement
is_backend_supported() to reflect available backends used by the wrapper; ensure
the decorator import is present and the helper functions are attached as
attributes on the function objects so other code can call
function.is_compute_capability_supported(cc) and
function.is_backend_supported().
- Around line 256-277: The branch that calls chunk_gated_delta_rule_blackwell
must first verify the optional Blackwell backend is available (it may be None
when flashinfer.gdn_kernels.blackwell_prefill failed to import); before invoking
chunk_gated_delta_rule_blackwell(q, k, v, g, beta, ...), check that
chunk_gated_delta_rule_blackwell is not None and if it is None raise a clear
NotImplementedError/RuntimeError like "Blackwell CuTe kernel not available on
this build" (keep the existing gate/beta check), so you fail with an explicit
availability error instead of a 'NoneType' object is not callable'.
---
Duplicate comments:
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 26-33: The test's skip flag is derived from importing
flashinfer.gdn_prefill instead of the actual Blackwell backend symbol; update
the check to verify the real backend export (e.g.
flashinfer.gdn_kernels.blackwell_prefill.chunk_gated_delta_rule) or use
pytest.importorskip("flashinfer.gdn_kernels.blackwell_prefill") to skip when the
Blackwell DSL is missing, and set GDN_BLACKWELL_PREFILL_CUTEDSL based on the
presence/non-None value of chunk_gated_delta_rule rather than the wrapper import
success.
---
Nitpick comments:
In `@tests/gdn/test_gdn_prefill_blackwell.py`:
- Around line 36-41: The test currently hardcodes testtype = torch.float16 so
BF16 paths are never exercised; change the test to parametrize testtype over
(torch.float16, torch.bfloat16) (or add a small smoke test that sets testtype =
torch.bfloat16) and adjust the tolerance logic that depends on testtype (oatol,
ortol, satol, srtol) to pick the correct values for each dtype; look for the
variable testtype and the tolerance assignments in the
test_gdn_prefill_blackwell tests and update them to run/assert for both dtypes
(use pytest.mark.parametrize on the surrounding test function or add an
additional test function that mirrors the current assertions but uses
torch.bfloat16).
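The availability guard asked for in the gdn_prefill.py comments can be sketched as follows. This is a stand-alone illustration: the `None` assignment stands in for the optional import fallback, and `dispatch_blackwell` is a hypothetical name, not the real dispatcher:

```python
# What the fallback import leaves behind when the Blackwell DSL backend is missing:
chunk_gated_delta_rule_blackwell = None


def dispatch_blackwell(*args, **kwargs):
    """Fail with an explicit message instead of
    "'NoneType' object is not callable"."""
    if chunk_gated_delta_rule_blackwell is None:
        raise NotImplementedError(
            "Blackwell CuTe kernel not available on this build"
        )
    return chunk_gated_delta_rule_blackwell(*args, **kwargs)


try:
    dispatch_blackwell()
except NotImplementedError as e:
    print(e)  # Blackwell CuTe kernel not available on this build
```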
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9b400e8a-8a96-489c-930a-cc00a4715bc2
📒 Files selected for processing (7)
- benchmarks/bench_blackwell_gdn_prefill.py
- flashinfer/gdn_kernels/__init__.py
- flashinfer/gdn_kernels/blackwell_prefill/__init__.py
- flashinfer/gdn_kernels/blackwell_prefill/gdn.py
- flashinfer/gdn_kernels/blackwell_prefill/gdn_tile_scheduler.py
- flashinfer/gdn_prefill.py
- tests/gdn/test_gdn_prefill_blackwell.py
kaixih left a comment
I know this PR is focused on SM100+ prefill support, but just want to clarify: since this prefill kernel outputs FP32 state while the fast decode path requires BF16 state, users will still need an explicit dtype conversion between prefill and decode until gap 5.4 (prefill BF16 state) is addressed. Is this a known limitation or something we plan to tackle soon?
Hi @kaixih, this is a known limitation and will be addressed in the next version.

What is preventing us from merging this PR?
flashinfer/gdn_prefill.py
Outdated
```python
    return output


@supported_compute_capability([90, 100, 110])
```
Does this support B300 (sm103)?
Yes this should support sm103
Thanks for replying. Could we please update @supported_compute_capability to include 103 as well?
/bot run

[FAILED] Pipeline #46876098: 1/20 passed

/bot run

[FAILED] Pipeline #46883505: 13/20 passed

/bot run

@dianzhangchen is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
flashinfer/gdn_prefill.py
Outdated
```python
    return output


@supported_compute_capability([90, 100, 103, 110])
```
Why do we remove 110 here?
I haven’t tested sm110 yet, so support is not confirmed.
/bot run

/bot run

@dianzhangchen is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

[FAILED] Pipeline #46950430: 9/20 passed
I find that this PR does not perform well for small head configurations.

```python
import torch
import torch.nn.functional as F
import numpy as np
from flashinfer.testing import bench_gpu_time
from flashinfer.gdn_prefill import chunk_gated_delta_rule as fi_gdn
from fla.ops.gated_delta_rule.chunk import chunk_gated_delta_rule_fwd as fla_gdn

CONFIGS = [
    # (batch, seq_len, h_qk, h_v, d, label)
    # Qwen3.5-397B (h_k=16, h_v=64, d=128) under different TP
    (1, 8192, 4, 16, 128, "TP4 B=1 S=8192"),
    (1, 4096, 4, 16, 128, "TP4 B=1 S=4096"),
    (1, 2048, 4, 16, 128, "TP4 B=1 S=2048"),
    (1, 8192, 8, 32, 128, "TP2 B=1 S=8192"),
    (1, 8192, 16, 64, 128, "TP1 B=1 S=8192"),
    # Original benchmark config (symmetric heads)
    (4, 4096, 32, 32, 128, "Symmetric B=4 S=4096"),
]

WARMUP = 5
ITERS = 20


def bench_flashinfer(B, T, h_qk, h_v, d):
    device = "cuda"
    dtype = torch.float16
    q = torch.randn((B, T, h_qk, d), dtype=dtype, device=device)
    k = F.normalize(
        torch.randn(B, T, h_qk, d, dtype=torch.float32, device=device), p=2, dim=-1
    ).to(dtype)
    v = torch.randn((B, T, h_v, d), dtype=dtype, device=device)
    g = F.logsigmoid(torch.rand(1, T * B, h_v, dtype=torch.float32, device=device))
    beta = torch.rand(1, T * B, h_v, dtype=torch.float32, device=device).sigmoid()
    h0 = torch.randn((B, h_v, d, d), dtype=torch.float32, device=device)
    state_out = torch.zeros_like(h0)
    fn = lambda: fi_gdn(q, k, v, g, beta, None, h0, True, None, False, None, state_out)
    times = bench_gpu_time(fn, enable_cupti=True, dry_run_iters=WARMUP, repeat_iters=ITERS)
    torch.cuda.empty_cache()
    return np.average(times)


def bench_fla(B, T, h_qk, h_v, d):
    device = "cuda"
    dtype = torch.float16
    # FLA doesn't support GVA (h_qk != h_v), expand q/k to h_v
    h = h_v
    q = torch.randn((B, T, h, d), dtype=dtype, device=device)
    k = F.normalize(
        torch.randn(B, T, h, d, dtype=torch.float32, device=device), p=2, dim=-1
    ).to(dtype)
    v = torch.randn((B, T, h_v, d), dtype=dtype, device=device)
    g = F.logsigmoid(torch.rand(1, T * B, h_v, dtype=torch.float32, device=device))
    beta = torch.rand(1, T * B, h_v, dtype=torch.float32, device=device).sigmoid()
    h0 = torch.randn((B, h_v, d, d), dtype=torch.float32, device=device)
    fn = lambda: fla_gdn(q, k, v, g, beta, None, initial_state=h0, output_final_state=True)
    times = bench_gpu_time(fn, enable_cupti=True, dry_run_iters=WARMUP, repeat_iters=ITERS)
    torch.cuda.empty_cache()
    return np.average(times)


def main():
    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"Model reference: Qwen3.5-397B-A17B (h_k=16, h_v=64, d=128, GVA ratio=4)")
    print()
    header = f"{'Config':<30s} {'h_qk':>4s} {'h_v':>4s} {'FlashInfer':>10s} {'FLA/Triton':>10s} {'Speedup':>8s}"
    print(header)
    print("─" * len(header))
    for B, T, h_qk, h_v, d, label in CONFIGS:
        fi_ms = bench_flashinfer(B, T, h_qk, h_v, d)
        fla_ms = bench_fla(B, T, h_qk, h_v, d)
        speedup = fla_ms / fi_ms
        marker = "✓" if speedup > 1.0 else "✗"
        print(
            f"{label:<30s} {h_qk:>4d} {h_v:>4d} {fi_ms:>9.3f}ms {fla_ms:>9.3f}ms {speedup:>7.2f}x {marker}"
        )


if __name__ == "__main__":
    main()
```
Hi @ZJY0516, this is a known issue, as mentioned in the PR: "In the current version, we do not parallelize over the sequence-length dimension, so performance is limited when both batch size and head count are small. In such cases, only part of the 148 SMs are used, so the GPU is not fully utilized."
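The utilization argument can be made concrete with a back-of-envelope sketch, assuming one CTA per batch/head pair and the 148 SMs mentioned above (the helper is illustrative, not part of the kernel):

```python
def sm_utilization(batch: int, num_heads: int, num_sms: int = 148) -> float:
    """Fraction of SMs kept busy when the grid has batch * num_heads CTAs
    and no parallelism over the sequence dimension."""
    return min(batch * num_heads / num_sms, 1.0)


# TP4 shape from the script above (B=1, h_v=16): most SMs sit idle.
print(f"{sm_utilization(1, 16):.1%}")  # 10.8%
# Symmetric B=4, H=32 config comes much closer to full occupancy.
print(f"{sm_utilization(4, 32):.1%}")  # 86.5%
```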
📌 Description
Gated Delta Networks (GDN) chunked linear attention kernel for NVIDIA Blackwell (SM100).
Implements the chunk-wise Gated Delta Rule linear attention using CuTe-DSL for the Blackwell architecture.
Key Features:
Performance on B200
Issues
🔍 Related Issues
chunk_gated_delta_rule for Blackwell #2340
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed, and all tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests