perf(gdn): fix bf16_state T=1 per-call overhead and add pool+padding … #3118

Draft

ameynaik-hub wants to merge 4 commits into flashinfer-ai:main from ameynaik-hub:ameyn/bf16_baseline_fix

Conversation

Contributor

ameynaik-hub commented Apr 19, 2026

…to ILP kernel

Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and the T=1 dispatch in flashinfer/gdn_decode.py:

  1. Per-call Python overhead fix
    Move torch.arange / torch.empty / from_dlpack / get_device_capability out
    of the steady-state call path. These were previously called on every
    invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
    of CUPTI-visible overhead per call at small BS. All default-tensor
    allocation and dlpack conversion is now done once, inside the
    `cache_key not in _compiled_kernels_*` block, and cached alongside the
    compiled kernel (see the sketch after this list). Steady-state calls pass
    raw torch tensors directly to the tvm-ffi callable. Adds module-level
    _USE_PACKED_FMA in place of per-call torch.cuda.get_device_capability().

  2. Pool + padding support on the ILP kernel
    gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
    accepts h0_indices and h0_out_indices, matching the MTP kernel's
    signature. Negative indices redirect to pool slot 0 (null buffer);
    writes go to a separate flat_write_idx so input and output pool slots
    can differ. The ILP launcher and gated_delta_rule wrapper thread the
    new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
    collapsed so pool+indices T=1 calls no longer detour through the
    heavier MTP kernel.

Design choice: kernel always takes indices (no constexpr switch).
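A minimal sketch of the caching pattern from change 1, with the cute.compile(...) call stubbed out behind a compile_fn closure and the cache key simplified (the real key also carries shape/dtype/config fields):

```python
import torch

# Detected once at import, replacing per-call torch.cuda.get_device_capability()
_GPU_MAJOR, _ = torch.cuda.get_device_capability(0)
_USE_PACKED_FMA = _GPU_MAJOR >= 10

_compiled_kernels_ilp: dict = {}


def _get_cache_entry(cache_key, q, B, T, HV, V, compile_fn):
    # First call for a config: allocate default tensors, run dlpack conversion,
    # and compile, all inside this block. Steady-state calls are one dict hit.
    if cache_key not in _compiled_kernels_ilp:
        _compiled_kernels_ilp[cache_key] = {
            "compiled": compile_fn(),  # stands in for cute.compile(...)
            "default_indices": torch.arange(B, dtype=torch.int32, device=q.device),
            "output": torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype),
        }
    return _compiled_kernels_ilp[cache_key]
```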

Benchmark: bench_gdn_decode.py (B200, HV=64, T=1, bf16, qk_l2norm ON)

Command:

```bash
python benchmarks/bench_gdn_decode.py \
    --batch-size 1 4 8 16 32 64 128 256 512 \
    --num-q-heads 32 --num-k-heads 32 --num-v-heads 64 \
    --head-size 128 --dtype bfloat16 --warmup 20 --iters 200
```

Bf16State column results (us):

| BS  | time   |
|-----|--------|
| 1   | 3.68   |
| 4   | 5.82   |
| 8   | 9.30   |
| 16  | 15.01  |
| 32  | 26.86  |
| 64  | 48.22  |
| 128 | 89.63  |
| 256 | 172.74 |
| 512 | 337.89 |

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Improved BF16 hidden-state handling to fully support pooled state reads/writes driven by per-batch slot indices, including correct handling of padded/negative indices and single-step sequences.
  • Performance

    • Reduced runtime overhead by caching GPU capability detection and reusing compiled kernels and default tensors for steady-state calls.

Contributor

coderabbitai Bot commented Apr 19, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9d86d3b8-1509-4ff3-b28f-3cde80e870da

📝 Walkthrough

This PR routes the BF16 T==1 GDN decode fast path through the pooled-state BF16 kernel, threads optional initial_state_indices/output_state_indices and output through host/kernel paths, updates the ILP kernel for separate read/write pool indices with -1→0 semantics, and caches packed-FMA capability at import time.

Changes

| Cohort / File(s) | Summary |
|---|---|
| BF16 T==1 dispatch logic<br>`flashinfer/gdn_decode.py` | Route all T==1 BF16 cases through the BF16-state kernel; choose initial_state_source based on use_pool and forward initial_state_indices/output_state_indices and output into the kernel invocation. |
| BF16 state kernel & host wrapper<br>`flashinfer/gdn_kernels/gdn_decode_bf16_state.py` | ILP kernel now accepts h0_indices/h0_out_indices, maps negative indices to slot 0, computes separate read/write flat offsets, and writes updated state to a write-rooted local_tile. Host gated_delta_rule/MTP signatures accept optional initial_state_indices, output_state_indices, output; state reshaping now uses pool_size*HV; compiled callables, default index tensors, and default outputs are cached and reused; GPU packed-FMA capability is detected once at import (_USE_PACKED_FMA). |
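For reference, the index semantics described above can be expressed as a plain-torch sketch; pooled_state_update_reference is a hypothetical host-side equivalent of what the CuTe kernel does, not the kernel itself:

```python
import torch

def pooled_state_update_reference(pool, h0_indices, h0_out_indices, new_states):
    # pool:           [pool_size, HV, V, K]; slot 0 doubles as the null buffer
    # h0_indices:     [B] int32 read slots; negative entries redirect to slot 0
    # h0_out_indices: [B] int32 write slots, allowed to differ from the read slots
    # new_states:     [B, HV, V, K] updated states to scatter back into the pool
    read_idx = h0_indices.clamp_min(0).long()
    initial = pool[read_idx]                  # per-row gather the kernel performs
    write_idx = h0_out_indices.clamp_min(0).long()
    pool[write_idx] = new_states              # separate flat write offsets
    return initial
```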

Sequence Diagram(s)

```mermaid
sequenceDiagram
  participant Caller
  participant Host as gated_delta_rule (host)
  participant Cache
  participant Kernel as BF16 ILP/MTP Kernel
  participant Pool as PooledStateMemory

  Caller->>Host: call gated_delta_rule(..., initial_state, use_pool, initial_state_indices?, output_state_indices?, output?)
  Host->>Cache: lookup/compile callable & default indices/output
  Cache-->>Host: callable + default tensors
  Host->>Kernel: invoke(callable, state_buf, h0_indices, h0_out_indices, output, config)
  Kernel->>Pool: read from pool using h0_indices (negative->0)
  Kernel->>Pool: write updated state using h0_out_indices (separate offsets)
  Kernel-->>Host: output tensor
  Host-->>Caller: return output (and optionally updated state)
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related issues

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • bkryu
  • jimmyzho
  • nvmbreughe
  • kahyunnam

Poem

🐰 I hopped through pools of bfloat delight,
threading indices left and indices right.
Reads find zero when negatives sneak,
writes go their way, no races to speak.
A tiny rabbit cheers: BF16 takes flight! 🥕✨

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly summarizes the two main changes: per-call overhead reduction for T=1 BF16 path and pool+padding support for the ILP kernel. |
| Description check | ✅ Passed | The PR description provides detailed explanations of both changes, benchmark results, and implementation details, though the template sections are largely unchecked. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%. |


Contributor

gemini-code-assist Bot left a comment
Code Review

This pull request introduces pool-based state management for GDN decode kernels, enabling indexed state access and updates for both ILP and MTP paths. It also optimizes the execution path by caching GPU architecture capabilities and reusing default buffers within the kernel compilation cache. The review feedback identifies a critical bug where sharing a default_output buffer across calls leads to unintended data overwrites, violating standard functional API expectations. Additionally, the reviewer pointed out safety concerns regarding hardcoded device indexing for architecture detection in multi-GPU environments and missing type casting for initial_state_indices which could lead to kernel crashes.

Comment on lines +2535 to +2536
```python
    if output is None:
        output = cache["output"]
```
Contributor

critical

Caching and returning a shared default_output buffer is a critical bug for a functional API. Subsequent calls to gated_delta_rule with the same shape will overwrite the results of previous calls if the user does not provide an explicit output tensor. This violates standard PyTorch expectations where a new tensor is returned for each call.

To reduce Python overhead without this side effect, consider allocating a new tensor in the steady state or using a dedicated workspace/pool object if this optimization is strictly necessary.
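A minimal repro of the hazard, with a hypothetical cached_call standing in for gated_delta_rule invoked without an explicit output tensor:

```python
import torch

_shared = torch.empty(4)  # stands in for cache["output"]

def cached_call(x):
    _shared.copy_(x)   # the kernel writes into the shared cached buffer
    return _shared     # and that same buffer is handed back to the caller

a = cached_call(torch.ones(4))
b = cached_call(torch.zeros(4))
print(a)  # tensor([0., 0., 0., 0.]); 'a' was silently overwritten by the second call
```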

Member

kahyunnam commented Apr 22, 2026

@ameynaik-hub can you take a look at the AI review bot comments here?

Comment on lines +2340 to +2341
```python
_GPU_MAJOR, _ = torch.cuda.get_device_capability(0)
_USE_PACKED_FMA = _GPU_MAJOR >= 10
```
Contributor

high

Hardcoding the GPU architecture detection to device(0) at the module level is unsafe in multi-GPU environments where different GPU generations (e.g., SM90 and SM100) might coexist. If device(0) is SM100 but the kernel is called on an SM90 device, it will attempt to use packed FMA instructions that are not supported, leading to a crash.

Consider caching the capability in a dictionary mapping device IDs to their respective capabilities to maintain performance while ensuring correctness across multiple devices.
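One possible shape for the suggested per-device cache (a sketch, not the PR's code):

```python
import functools

import torch

@functools.lru_cache(maxsize=None)
def _use_packed_fma(device_index: int) -> bool:
    # Queried once per device index, then served from the cache
    major, _ = torch.cuda.get_device_capability(device_index)
    return major >= 10

# At call time, key off the tensor's actual device:
#   if _use_packed_fma(q.device.index): ...
```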

Member

@ameynaik-hub can you take a look here? This might also have to be torch.cuda.current_device().

Comment on lines +2419 to +2421
```python
    use_pool = initial_state_indices is not None
    if output_state_indices is not None and output_state_indices.dtype != torch.int32:
        output_state_indices = output_state_indices.to(torch.int32)
```
Contributor

medium

While output_state_indices is converted to int32, initial_state_indices is not. If a user passes int64 indices (the default for torch.arange or long tensors), it may cause a crash or undefined behavior within the CuTe kernel which expects 32-bit integers.

Suggested change

```diff
 use_pool = initial_state_indices is not None
+if use_pool and initial_state_indices.dtype != torch.int32:
+    initial_state_indices = initial_state_indices.to(torch.int32)
 if output_state_indices is not None and output_state_indices.dtype != torch.int32:
     output_state_indices = output_state_indices.to(torch.int32)
```
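The int64 default is easy to hit in practice, since torch integer factory functions default to int64:

```python
import torch

idx = torch.arange(8)        # dtype defaults to torch.int64
print(idx.dtype)             # torch.int64
idx = idx.to(torch.int32)    # what the suggested normalization produces
```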

Contributor

coderabbitai Bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/gdn_decode.py (2)

211-238: ⚠️ Potential issue | 🟠 Major

Validate and normalize initial_state_indices before BF16 dispatch.

Line 274 now forwards initial_state_indices into the BF16 T=1 backend, but the pool path only validates initial_state. A wrong-shape, CPU, or non-int32 read-index tensor can reach the kernel unchecked.

Suggested validation

```diff
     if output_state_indices is not None:
         assert use_pool, (
             "output_state_indices can only be used with initial_state (pool mode)"
         )
@@
         assert output_state_indices.dtype in (torch.int32, torch.int64), (
             f"output_state_indices must be int32 or int64, "
             f"got {output_state_indices.dtype}"
         )
+        assert output_state_indices.device == q.device, (
+            f"output_state_indices must be on {q.device}, "
+            f"got {output_state_indices.device}"
+        )
+        if output_state_indices.dtype != torch.int32:
+            output_state_indices = output_state_indices.to(torch.int32)

     if use_pool:
+        assert initial_state_indices is not None
+        assert initial_state_indices.shape == (B,), (
+            f"Expected initial_state_indices shape [{B}], "
+            f"got {initial_state_indices.shape}"
+        )
+        assert initial_state_indices.dtype in (torch.int32, torch.int64), (
+            f"initial_state_indices must be int32 or int64, "
+            f"got {initial_state_indices.dtype}"
+        )
+        assert initial_state_indices.device == q.device, (
+            f"initial_state_indices must be on {q.device}, "
+            f"got {initial_state_indices.device}"
+        )
+        if initial_state_indices.dtype != torch.int32:
+            initial_state_indices = initial_state_indices.to(torch.int32)
         pool_size = initial_state.shape[0]
```

Also applies to: 274-276


264-279: ⚠️ Potential issue | 🟠 Major

Thread compatible caller-provided output into the BF16 backend.

The backend now accepts output, but this wrapper still writes into the backend’s cached default buffer and copies afterward. That preserves an avoidable scratch-buffer race for concurrent calls and adds the copy back on the hot path.

Suggested forwarding pattern

```diff
         assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}"
         scale_val = K**-0.5 if scale is None else scale
+        backend_output = (
+            output
+            if output is not None
+            and output.shape == (B, T, HV, V)
+            and output.device == q.device
+            and output.dtype == q.dtype
+            else None
+        )
         if T == 1:
@@
                 output_state_indices=output_state_indices,
+                output=backend_output,
                 use_qk_l2norm_in_kernel=use_qk_l2norm,
                 scale=scale_val,
             )
         else:
@@
                 output_state_indices=output_state_indices,
+                output=backend_output,
                 use_qk_l2norm_in_kernel=use_qk_l2norm,
                 scale=scale_val,
             )
```

Also applies to: 282-297

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 94e11a71-62c6-424c-a912-8bce379e813f

📥 Commits

Reviewing files that changed from the base of the PR and between 8559397 and 210eb20b0707e1c009e1c2e0362be8cc2c77da73.

📒 Files selected for processing (2)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py

Comment on lines +2419 to +2421
```python
    use_pool = initial_state_indices is not None
    if output_state_indices is not None and output_state_indices.dtype != torch.int32:
        output_state_indices = output_state_indices.to(torch.int32)
```
Contributor

⚠️ Potential issue | 🟠 Major



Normalize initial_state_indices to int32 like output_state_indices.

Both functions accept initial_state_indices documented as int32 in their docstrings, but only normalize output_state_indices when passed with a different dtype. If callers provide int64 tensors for initial_state_indices, they reach the CuTe kernel with a mismatched dtype. This should be handled consistently:

Suggested fix

```diff
     use_pool = initial_state_indices is not None
+    if initial_state_indices is not None and initial_state_indices.dtype != torch.int32:
+        initial_state_indices = initial_state_indices.to(torch.int32)
     if output_state_indices is not None and output_state_indices.dtype != torch.int32:
         output_state_indices = output_state_indices.to(torch.int32)
```

Also applies to: 2639-2640


Comment on lines +2445 to +2448
```python
    # Reshape state: no-pool [B, HV, V, K] -> [B*HV, V, K];
    # pool [pool_size, HV, V, K] -> [pool_size*HV, V, K].
    pool_size = initial_state_source.shape[0] if use_pool else B
    h0_source = initial_state_source.reshape(pool_size * HV, V, K)
```
Contributor

⚠️ Potential issue | 🟠 Major


Add contiguity check for reshaped pooled state before kernel dispatch.

The reshape(pool_size * HV, V, K) operation can silently allocate a copy when initial_state_source has non-contiguous strides across the pool_size and HV dimensions, even though the public API enforces only K-contiguity (stride[-1] == 1). The kernel then updates the copy and the caller's pool remains unchanged, violating the documented in-place update contract.

Adopt the pattern already applied to intermediate_states (lines 2657–2658): explicitly check and make contiguous before passing to from_dlpack().

```diff
     # Reshape state: no-pool [B, HV, V, K] -> [B*HV, V, K];
     # pool [pool_size, HV, V, K] -> [pool_size*HV, V, K].
     pool_size = initial_state_source.shape[0] if use_pool else B
     h0_source = initial_state_source.reshape(pool_size * HV, V, K)
+    if not h0_source.is_contiguous():
+        h0_source = h0_source.contiguous()
```

Apply the same fix at line 2642–2643 in gated_delta_rule_mtp().

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

```diff
 # Reshape state: no-pool [B, HV, V, K] -> [B*HV, V, K];
 # pool [pool_size, HV, V, K] -> [pool_size*HV, V, K].
 pool_size = initial_state_source.shape[0] if use_pool else B
 h0_source = initial_state_source.reshape(pool_size * HV, V, K)
+if not h0_source.is_contiguous():
+    h0_source = h0_source.contiguous()
```
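A small demonstration of the failure mode with toy shapes (a sketch, not the PR's code): reshaping a strided pool view silently materializes a copy, so in-place writes never reach the caller's buffer:

```python
import torch

pool_storage = torch.zeros(8, 2, 4, 4, 4)  # [pool_size, page_gap, HV, V, K]-style layout
pool = pool_storage[:, 1]                  # non-contiguous [8, 4, 4, 4] view
h0 = pool.reshape(8 * 4, 4, 4)             # strides are incompatible, so reshape copies

print(pool.is_contiguous())                # False
print(h0.data_ptr() == pool.data_ptr())    # False: h0 is a fresh allocation
h0.fill_(1.0)                              # a "kernel" write lands in the copy...
print(pool.sum().item())                   # 0.0: ...and the caller's pool never sees it
```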

Comment on lines +2474 to +2536
```diff
         default_output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
         default_indices = torch.arange(B, dtype=torch.int32, device=q.device)

         q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True)
         k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True)
         v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True)
         a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True)
         b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True)
         A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True)
         dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
         h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True)
         o_ = from_dlpack(default_output, assumed_align=32, enable_tvm_ffi=True)
         h0_idx_ = from_dlpack(default_indices, assumed_align=32, enable_tvm_ffi=True)
         h0_out_idx_ = from_dlpack(
             default_indices, assumed_align=32, enable_tvm_ffi=True
         )

         # Use maxrregcount=64 for smaller tile_v to improve occupancy
         # when grid size is small (fewer waves)
         if tile_v < 128:
             compile_opts = "--enable-tvm-ffi --generate-line-info --opt-level 3 --ptxas-options=-maxrregcount=64"
         else:
             compile_opts = "--enable-tvm-ffi --generate-line-info --opt-level 3"
-        _compiled_kernels_ilp[cache_key] = cute.compile(
-            run_gdn_decode_bf16state_ilp,
-            h_,
-            A_log_,
-            a_,
-            dt_bias_,
-            q_,
-            k_,
-            v_,
-            b_,
-            o_,
-            softplus_beta,
-            softplus_threshold,
-            scale,
-            HV,
-            B,
-            H,
-            K,
-            V,
-            use_qk_l2norm_in_kernel,
-            use_packed_fma,
-            tile_v,
-            stream,
-            options=compile_opts,
-        )
+        _compiled_kernels_ilp[cache_key] = {
+            "compiled": cute.compile(
+                run_gdn_decode_bf16state_ilp,
+                h_,
+                A_log_,
+                a_,
+                dt_bias_,
+                q_,
+                k_,
+                v_,
+                b_,
+                o_,
+                h0_idx_,
+                h0_out_idx_,
+                softplus_beta,
+                softplus_threshold,
+                scale,
+                HV,
+                B,
+                H,
+                K,
+                V,
+                use_qk_l2norm_in_kernel,
+                use_packed_fma,
+                tile_v,
+                stream,
+                options=compile_opts,
+            ),
+            "output": default_output,
+            "default_indices": default_indices,
+        }

     cache = _compiled_kernels_ilp[cache_key]

     if initial_state_indices is None:
         initial_state_indices = cache["default_indices"]
     if output_state_indices is None:
         output_state_indices = initial_state_indices
     if output is None:
         output = cache["output"]
```
Contributor

⚠️ Potential issue | 🔴 Critical

Avoid returning a cached mutable output buffer.

When output is None, these functions return cache["output"]. A later same-shape call overwrites the previous result tensor, so callers that keep the returned output can observe it mutate unexpectedly.

Safer output handling

```diff
-        _compiled_kernels_ilp[cache_key] = {
+        _compiled_kernels_ilp[cache_key] = {
             "compiled": cute.compile(
@@
             ),
-            "output": default_output,
             "default_indices": default_indices,
         }
@@
     if output is None:
-        output = cache["output"]
+        output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
-        _compiled_kernels_mtp[cache_key] = {
+        _compiled_kernels_mtp[cache_key] = {
             "compiled": cute.compile(
@@
             ),
             "default_indices": default_indices,
-            "output": default_output,
         }
@@
     if output is None:
-        output = cache["output"]
+        output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
```

Also applies to: 2693-2754


…to ILP kernel

Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and
the T=1 dispatch in flashinfer/gdn_decode.py:

1) Per-call Python overhead fix
   Move torch.arange / torch.empty / from_dlpack / get_device_capability out
   of the steady-state call path. These were previously called on every
   invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
   of CUPTI-visible overhead per call at small BS. All default-tensor
   allocation and dlpack conversion is now done once, inside the
   `cache_key not in _compiled_kernels_*` block, and cached alongside the
   compiled kernel. Steady-state calls pass raw torch tensors directly to
   the tvm-ffi callable. Adds module-level _USE_PACKED_FMA in place of
   per-call torch.cuda.get_device_capability().

2) Pool + padding support on the ILP kernel
   gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
   accepts h0_indices and h0_out_indices, matching the MTP kernel's
   signature. Negative indices redirect to pool slot 0 (null buffer);
   writes go to a separate flat_write_idx so input and output pool slots
   can differ. The ILP launcher and gated_delta_rule wrapper thread the
   new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
   collapsed so pool+indices T=1 calls no longer detour through the
   heavier MTP kernel.

Design choice: kernel always takes indices (no constexpr switch).

Benchmark config: Qwen3.5-397B-A17B linear attention
(num_q_heads=16, num_k_heads=16, num_v_heads=64, head_size=128, bf16, qk_l2norm ON)
GPU: NVIDIA B200

Command:
    python benchmarks/bench_gdn_decode.py \
        --batch-size 1 4 8 16 32 64 128 256 512 \
        --num-q-heads 16 --num-k-heads 16 --num-v-heads 64 \
        --head-size 128 --dtype bfloat16 --warmup 20 --iters 200

Bf16State column results (us):

  BS | time
   1 |   3.71
   4 |   5.89
   8 |   9.18
  16 |  14.98
  32 |  26.56
  64 |  48.24
 128 |  89.66
 256 | 172.35
 512 | 337.60

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
ameynaik-hub force-pushed the ameyn/bf16_baseline_fix branch from 210eb20 to 31288e5 on April 19, 2026 02:32
@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

Mirroring Failed

Failed to mirror PR to GitLab. Check logs for details.

@yongwww
Member

yongwww commented Apr 20, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !571 has been created, and the CI pipeline #49010651 is currently running. I'll report back once the pipeline job completes.

@yongwww
Member

yongwww commented Apr 20, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #49010651 has been cancelled.

@yongwww
Member

yongwww commented Apr 20, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !571 has been updated with latest changes, and the CI pipeline #49010926 is currently running. I'll report back once the pipeline job completes.

@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !571 has been updated with latest changes, and the CI pipeline #49126615 is currently running. I'll report back once the pipeline job completes.

Member

kahyunnam left a comment

Mostly LGTM, but left a few comments

```diff
@@ -2294,6 +2335,11 @@
 # Number of SMs on target GPU (detected dynamically)
 NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count
```
Member

kahyunnam commented Apr 22, 2026

I realize this is not part of the PR, but it's kind of related to the changes below. This also probably should be the current device, not index 0; we can re-use this utils function, which also caches: https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/cute_dsl/utils.py#L95-L98

Comment on lines +2340 to +2341
```python
_GPU_MAJOR, _ = torch.cuda.get_device_capability(0)
_USE_PACKED_FMA = _GPU_MAJOR >= 10
```
Member

@ameynaik-hub can you take a look here? This might also have to be torch.cuda.current_device().

Comment on lines +2535 to +2536
```python
    if output is None:
        output = cache["output"]
```
Member

kahyunnam commented Apr 22, 2026

@ameynaik-hub can you take a look at the AI review bot comments here?

```python
# PUBLIC API
# ==============================================================================
_compiled_kernels: dict = {}
_compiled_kernels_ilp: dict = {}
```
Member

This is not a blocker for this PR, but I'm not 100% following what is going on here; manually managing a cache dictionary seems to do the same thing as @functools.cache.

Is there a reason we're not following the @functools.cache pattern like the other cute dsl backends, like here? https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_kernels/gdn_decode_pretranspose.py#L886-L906
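For comparison, the @functools.cache pattern being referenced, with the expensive cute.compile(...) call stubbed out; a cached helper can bundle default tensors in its return value exactly like the manual dict entry does:

```python
import functools

@functools.cache
def _get_compiled_kernel(cache_key: tuple):
    # Compile once per hashable config tuple; functools.cache memoizes the
    # result, which is what the manual _compiled_kernels_ilp dict does by hand.
    compiled = f"compiled-for-{cache_key}"  # stands in for cute.compile(...)
    return compiled

k1 = _get_compiled_kernel((64, 128, True))  # compiles on first call
k2 = _get_compiled_kernel((64, 128, True))  # cache hit
assert k1 is k2
```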

ameynaik-hub marked this pull request as draft April 28, 2026
ameynaik-hub mentioned this pull request Apr 28, 2026
kahyunnam pushed a commit that referenced this pull request May 6, 2026
<!-- .github/pull_request_template.md -->

  ## Summary

  Replaces the legacy `gdn_decode_bf16state_cooprow_kernel` and the
  `gdn_decode_bf16state_mtp_kernel` (ILP=8) with a new
  **`gdn_wide_vec_kernel`** (LDG.E.128 / STG.E.128 fast path) plus a
  small-batch `mtp_ilp4` fallback. Drops ~1900 LOC of dead/unused code,
  adds split-pool support (#2905-compatible) to both surviving BF16
  kernels, and ships the OOB fix mirroring upstream PR #3145 — for the
  BF16 kernels that survived the cleanup.

  **Supersedes #3118.** That PR's perf delta (T=1 per-call overhead +
  pool+padding for the ILP kernel) is the first commit on this branch
  (`8a6e9819`). 
  
  ## What changes

  - **New kernel**: `gdn_wide_vec_kernel` — 128 threads/CTA = 8 groups
    × 16 threads, vec=8 BF16 → LDG.E.128 / STG.E.128, ILP=4 V-rows per
    thread. Configurable `tile_v ∈ {32, 64, 128}` so the kernel covers
    small/medium/large `B*HV` work-unit sizes uniformly.
  - **Pool-only**: BF16 GDN dispatch is strictly pool-mode (matches
    the production serving contract). Wrapper
    `gated_delta_rule_decode_pretranspose` auto-promotes legacy non-pool
    callers internally — public API unchanged.
  - **Split-pool support** (PR #2905 contract): both surviving BF16
kernels (`gdn_wide_vec_kernel`, `gdn_decode_bf16state_mtp_ilp4_kernel`)
    natively support `output_state_indices != initial_state_indices`,
    with bit-equivalent single-pool behavior selected at compile time
    via `Constexpr[bool] same_pool` for zero-overhead dispatch.
  - **OOB fix (PR #3145 equivalent)**: `intermediate_states` is indexed
    by the per-call batch index `i_n` (not the pool-scoped `cache_idx`),
    so the buffer can be sized `[B, T, HV, V, K]` as production callers
    expect. Regression test catches the bug; pre-fix triggers
    `cudaErrorIllegalAddress` in <2 s.

  ## Removed (~1900 LOC of dead code)

  | Kernel | Why removed |
  |---|---|
  | `gdn_decode_bf16state_cooprow_kernel` (~280 LOC) | Replaced by wide_vec + ILP=4 MTP; had known correctness issues at small batch |
  | `gdn_decode_bf16state_ilp_kernel` (~740 LOC) | Only reachable at HV<32 with B≥16 — not a Qwen3.5 shape; MTP path covers it |
  | `gdn_decode_bf16state_mtp_kernel` (ILP=8) (~940 LOC) | After wide_vec extension to split-pool + tile_v=32, mtp_kernel was unreachable |

  End-state BF16 surface = **2 `@cute.kernel`s in one file**:
  - `gdn_wide_vec_kernel` — production hot path
  - `gdn_decode_bf16state_mtp_ilp4_kernel` — small-batch fallback

  Both pool-only, both split-pool capable, both indexed batch-scoped.

  ## Speedup vs previous baseline

Baseline = pre-wide_vec dispatch (the `mtp_kernel` ILP=8 path, captured
  on this same branch by monkey-patching `_select_wide_vec_tile_v` to
  return `None` for every shape). Same harness, same hardware, same
  config — so the comparison isolates the kernel-level speedup that
  wide_vec + the cleanup deliver.

  Setup: B200, HV=64, K=V=128, BF16, qk_l2norm=ON, warmup=5, iters=50,
  T=1 invoked with `--update-state`, T≥2 invoked with
  `--cache-intermediate-states`. Kernel time in microseconds (CUPTI).

  ### Speedup (×, baseline / post-PR)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|-------|-------|-------|-------|-------|-------|-------|
| 1 | 1.03× | 1.04× | 1.03× | 1.03× | 1.00× | 1.02× | 1.02× | 1.01× |
| 4 | 0.97× | 1.23× | 1.10× | 1.12× | 1.11× | 1.11× | 1.14× | 1.14× |
| 8 | 1.08× | 1.11× | 1.11× | 1.12× | 1.13× | 1.15× | 1.15× | 1.14× |
| 16 | 1.04× | 1.09× | 1.11× | 1.13× | 1.13× | 1.12× | 1.11× | 1.10× |
| 32 | 1.06× | 1.12× | 1.10× | 1.11× | 1.10× | 1.09× | 1.09× | 1.09× |
| 64 | 1.04× | 1.11× | 1.08× | 1.09× | 1.06× | 1.07× | 1.08× | 1.06× |
| 128 | 1.04× | 1.11× | 1.07× | 1.09× | 1.06× | 1.06× | 1.07× | 1.07× |
| 256 | 1.04× | 1.11× | 1.07× | 1.09× | 1.07× | 1.06× | 1.07× | 1.07× |

  ### Time reduction (%)

| B   | T=1   | T=2    | T=3    | T=4    | T=5    | T=6    | T=7    | T=8    |
|-----|-------|--------|--------|--------|--------|--------|--------|--------|
| 1   | +2.8% | +3.5%  | +3.2%  | +3.3%  | +0.1%  | +1.7%  | +1.8%  | +0.7%  |
| 4   | −3.2% | +18.7% | +9.5%  | +10.7% | +9.7%  | +10.2% | +12.3% | +12.5% |
| 8   | +7.8% | +10.1% | +10.0% | +10.7% | +11.8% | +12.7% | +12.8% | +12.4% |
| 16  | +3.4% | +8.5%  | +10.0% | +11.2% | +11.4% | +10.7% | +10.3% | +9.4%  |
| 32  | +6.0% | +10.4% | +9.4%  | +9.6%  | +8.8%  | +8.4%  | +8.4%  | +7.8%  |
| 64  | +4.0% | +10.3% | +7.5%  | +8.6%  | +5.9%  | +6.3%  | +7.2%  | +5.9%  |
| 128 | +4.2% | +9.8%  | +6.5%  | +8.4%  | +6.0%  | +5.9%  | +6.3%  | +6.6%  |
| 256 | +4.2% | +9.6%  | +6.6%  | +8.4%  | +6.3%  | +5.9%  | +6.5%  | +6.6%  |

  ### Headline

- **T=1 production decode (B≥16)**: 4–6 % time reduction across the full
batch sweep — the Qwen3.5 hot path.
- **T≥2 with cache=ON (B≥4)**: 6–18 % time reduction at every shape.
Best at small-T / mid-batch (B=4 T=2: 1.23×; B=8 T=6: 1.15×).
- **Tiny shapes (B=1)**: within ±3 % of baseline (kernel isn't
DRAM-bound; small fixed-cost overheads dominate; the ILP=4 fallback was
already efficient there).

  ### Sustained DRAM bandwidth post-PR (TB/s, 8 TB/s peak on B200)

  |  B  | T=1  | T=2  | T=3  | T=4  | T=5  | T=6  | T=7  | T=8  |
  |-----|------|------|------|------|------|------|------|------|
  |   1 | 1.25 | 1.21 | 1.49 | 1.63 | 1.45 | 1.55 | 1.64 | 1.70 |
  |   4 | 2.83 | 3.24 | 3.54 | 3.78 | 3.53 | 3.68 | 3.84 | 3.95 |
  |   8 | 3.97 | 4.09 | 4.40 | 4.54 | 4.42 | 4.55 | 4.59 | 4.61 |
  |  16 | 4.73 | 4.73 | 5.02 | 5.03 | 4.95 | 4.91 | 4.92 | 4.87 |
  |  32 | 5.39 | 5.36 | 5.44 | 5.46 | 5.27 | 5.23 | 5.21 | 5.17 |
  |  64 | 5.83 | 5.76 | 5.80 | 5.77 | 5.45 | 5.44 | 5.45 | 5.33 |
  | 128 | 6.31 | 6.05 | 6.03 | 6.01 | 5.68 | 5.61 | 5.57 | 5.54 |
  | 256 | 6.57 | 6.23 | 6.20 | 6.17 | 5.85 | 5.74 | 5.72 | 5.66 |

Post-PR peaks at **6.57 TB/s = 82 % of B200 peak DRAM** (T=1 B=256
production decode shape).

  ### Split-pool

  With wide_vec now supporting split-pool natively, split-pool
  matches single-pool to within ±1 % at every measured shape.

  ## Tests

  > **513 passed, 0 failed in 18m18s** on B200.

Including: 477 existing BF16/wide_vec/pool tests, 12 new split-pool MTP
tests, 12 new OOB regression tests covering `pool_size_multiplier ∈ {1,
4}` × `B ∈ {1, 8, 32}` × `T ∈ {2, 4}`, 12 wrapper-level split-pool
tests.

  ## Files changed (4)

  - `flashinfer/gdn_decode.py` — wrapper auto-promotes BF16 non-pool → pool
  - `flashinfer/gdn_kernels/gdn_decode_bf16_state.py` — wide_vec inlined; dead kernels removed; split-pool plumbing; OOB fix; same_pool DCE
  - `tests/gdn/test_decode_delta_rule.py` — split-pool + OOB regression tests
  - `benchmarks/bench_gdn_decode.py` — `--pool-mode {single,split}` flag

<!-- What does this PR do? Briefly describe the changes and why they’re needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added `--pool-mode` option to benchmark tool for configuring state
pool allocation (`single` or `split` modes).

* **Tests**
* Expanded BF16 test coverage with regression tests for split-pool
semantics and out-of-bounds scenarios; improved batch-dimension handling
for intermediate-state comparisons.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
