perf(gdn): fix bf16_state T=1 per-call overhead and add pool+padding … #3118

ameynaik-hub wants to merge 4 commits into flashinfer-ai:main
Conversation
Important: Review skipped (draft detected). Configuration used: defaults. Review profile: CHILL. Plan: Pro.
📝 Walkthrough

This PR routes the BF16 T==1 GDN decode fast path through the pooled-state BF16 kernel and threads optional initial_state_indices, output_state_indices, and output tensors through the host wrapper and compilation cache.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Caller
participant Host as gated_delta_rule (host)
participant Cache
participant Kernel as BF16 ILP/MTP Kernel
participant Pool as PooledStateMemory
Caller->>Host: call gated_delta_rule(..., initial_state, use_pool, initial_state_indices?, output_state_indices?, output?)
Host->>Cache: lookup/compile callable & default indices/output
Cache-->>Host: callable + default tensors
Host->>Kernel: invoke(callable, state_buf, h0_indices, h0_out_indices, output, config)
Kernel->>Pool: read from pool using h0_indices (negative->0)
Kernel->>Pool: write updated state using h0_out_indices (separate offsets)
Kernel-->>Host: output tensor
Host-->>Caller: return output (and optionally updated state)
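The flow in the diagram above can be sketched host-side in plain Python (a minimal stand-in with illustrative names only, not the real flashinfer API; the actual kernel is a compiled CuTe callable operating on CUDA tensors):

```python
# Minimal sketch of the pooled-state dispatch flow from the diagram.
# All names here are hypothetical stand-ins for illustration.

_kernel_cache = {}  # maps a config key -> (callable, default identity indices)

def _stub_kernel(pool, h0_indices, h0_out_indices):
    # Read each state via its read index (negative -> treated as slot 0),
    # "update" it, and write it back through a separate output index.
    out = []
    for read_i, write_i in zip(h0_indices, h0_out_indices):
        state = pool[max(read_i, 0)]
        new_state = state + 1          # stand-in for the real state update
        pool[write_i] = new_state      # write-back via a separate offset
        out.append(new_state)
    return out

def gated_delta_rule_sketch(pool, batch, h0_indices=None, h0_out_indices=None):
    key = ("bf16", batch)
    if key not in _kernel_cache:
        # Compile once; cache default identity indices alongside the callable.
        _kernel_cache[key] = (_stub_kernel, list(range(batch)))
    kernel, default_idx = _kernel_cache[key]
    h0_indices = default_idx if h0_indices is None else h0_indices
    h0_out_indices = default_idx if h0_out_indices is None else h0_out_indices
    return kernel(pool, h0_indices, h0_out_indices)
```

For example, with a pool of four slots, a batch of two can read slots 3 and 2 while writing the updated states to slots 0 and 1; omitting the indices falls back to the cached identity mapping.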
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 passed
Code Review
This pull request introduces pool-based state management for GDN decode kernels, enabling indexed state access and updates for both ILP and MTP paths. It also optimizes the execution path by caching GPU architecture capabilities and reusing default buffers within the kernel compilation cache. The review feedback identifies a critical bug where sharing a default_output buffer across calls leads to unintended data overwrites, violating standard functional API expectations. Additionally, the reviewer pointed out safety concerns regarding hardcoded device indexing for architecture detection in multi-GPU environments and missing type casting for initial_state_indices which could lead to kernel crashes.
if output is None:
    output = cache["output"]
Caching and returning a shared default_output buffer is a critical bug for a functional API. Subsequent calls to gated_delta_rule with the same shape will overwrite the results of previous calls if the user does not provide an explicit output tensor. This violates standard PyTorch expectations where a new tensor is returned for each call.
To reduce Python overhead without this side effect, consider allocating a new tensor in the steady state or using a dedicated workspace/pool object if this optimization is strictly necessary.
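The aliasing hazard described here can be shown with a plain-Python sketch (illustrative names only; in the PR the shared buffer is a cached CUDA tensor):

```python
# A cached "default output" buffer shared across calls: the second call
# silently clobbers the result the first caller is still holding.
_buf_cache = {}

def op_shared(x):
    # Reuses one buffer per input length across calls (the bug pattern).
    buf = _buf_cache.setdefault(len(x), [0] * len(x))
    for i, v in enumerate(x):
        buf[i] = v * 2
    return buf

def op_fresh(x):
    # Allocating per call restores the expected functional semantics.
    return [v * 2 for v in x]

a = op_shared([1, 2, 3])   # a is [2, 4, 6] at this point
b = op_shared([7, 8, 9])   # b is [14, 16, 18], and a now aliases b
```

After the second call, `a` and `b` are the same list object, so the first result is gone; `op_fresh` avoids this at the cost of one allocation per call.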
@ameynaik-hub can you take a look at the AI review bot comments here?
_GPU_MAJOR, _ = torch.cuda.get_device_capability(0)
_USE_PACKED_FMA = _GPU_MAJOR >= 10
Hardcoding the GPU architecture detection to device(0) at the module level is unsafe in multi-GPU environments where different GPU generations (e.g., SM90 and SM100) might coexist. If device(0) is SM100 but the kernel is called on an SM90 device, it will attempt to use packed FMA instructions that are not supported, leading to a crash.
Consider caching the capability in a dictionary mapping device IDs to their respective capabilities to maintain performance while ensuring correctness across multiple devices.
@ameynaik-hub can you take a look here? This might also have to be torch.cuda.current_device().
use_pool = initial_state_indices is not None
if output_state_indices is not None and output_state_indices.dtype != torch.int32:
    output_state_indices = output_state_indices.to(torch.int32)
While output_state_indices is converted to int32, initial_state_indices is not. If a user passes int64 indices (the default for torch.arange or long tensors), it may cause a crash or undefined behavior within the CuTe kernel which expects 32-bit integers.
Suggested change:

use_pool = initial_state_indices is not None
if use_pool and initial_state_indices.dtype != torch.int32:
    initial_state_indices = initial_state_indices.to(torch.int32)
if output_state_indices is not None and output_state_indices.dtype != torch.int32:
    output_state_indices = output_state_indices.to(torch.int32)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/gdn_decode.py (2)
211-238: ⚠️ Potential issue | 🟠 Major: Validate and normalize initial_state_indices before BF16 dispatch.

Line 274 now forwards initial_state_indices into the BF16 T=1 backend, but the pool path only validates initial_state. A wrong-shape, CPU, or non-int32 read-index tensor can reach the kernel unchecked.

Suggested validation:

 if output_state_indices is not None:
     assert use_pool, (
         "output_state_indices can only be used with initial_state (pool mode)"
     )
@@
     assert output_state_indices.dtype in (torch.int32, torch.int64), (
         f"output_state_indices must be int32 or int64, "
         f"got {output_state_indices.dtype}"
     )
+    assert output_state_indices.device == q.device, (
+        f"output_state_indices must be on {q.device}, "
+        f"got {output_state_indices.device}"
+    )
+    if output_state_indices.dtype != torch.int32:
+        output_state_indices = output_state_indices.to(torch.int32)
 if use_pool:
+    assert initial_state_indices is not None
+    assert initial_state_indices.shape == (B,), (
+        f"Expected initial_state_indices shape [{B}], "
+        f"got {initial_state_indices.shape}"
+    )
+    assert initial_state_indices.dtype in (torch.int32, torch.int64), (
+        f"initial_state_indices must be int32 or int64, "
+        f"got {initial_state_indices.dtype}"
+    )
+    assert initial_state_indices.device == q.device, (
+        f"initial_state_indices must be on {q.device}, "
+        f"got {initial_state_indices.device}"
+    )
+    if initial_state_indices.dtype != torch.int32:
+        initial_state_indices = initial_state_indices.to(torch.int32)
     pool_size = initial_state.shape[0]

Also applies to: 274-276
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 211 - 238, The code fails to validate/normalize initial_state_indices before dispatching to the BF16 T=1 backend: ensure that when use_pool is True you validate initial_state_indices (shape == (B,), dtype int32 or int64, device matches initial_state/device, and is contiguous) and then normalize/convert it to the kernel-expected type/device (e.g., move to the same device as initial_state and convert to torch.int32 or torch.int64 as required) before passing it into the BF16 backend; update the logic around initial_state_indices and the BF16 dispatch (the variables initial_state_indices, use_pool and the BF16 T=1 backend call) so a malformed, CPU, or wrong-dtype index tensor cannot reach the kernel.
264-279: ⚠️ Potential issue | 🟠 Major: Thread a compatible caller-provided output into the BF16 backend.

The backend now accepts output, but this wrapper still writes into the backend's cached default buffer and copies afterward. That preserves an avoidable scratch-buffer race for concurrent calls and adds the copy back on the hot path.

Suggested forwarding pattern:

 assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}"
 scale_val = K**-0.5 if scale is None else scale
+backend_output = (
+    output
+    if output is not None
+    and output.shape == (B, T, HV, V)
+    and output.device == q.device
+    and output.dtype == q.dtype
+    else None
+)
 if T == 1:
@@
         output_state_indices=output_state_indices,
+        output=backend_output,
         use_qk_l2norm_in_kernel=use_qk_l2norm,
         scale=scale_val,
     )
 else:
@@
         output_state_indices=output_state_indices,
+        output=backend_output,
         use_qk_l2norm_in_kernel=use_qk_l2norm,
         scale=scale_val,
     )

Also applies to: 282-297
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 264 - 279, The wrapper is calling _gated_delta_rule_bf16_state but still writes into the backend’s cached default buffer and copies out, causing a race and extra copy; update the call sites in gdn_decode.py (the current _gated_delta_rule_bf16_state invocation and the similar invocation later) to accept a caller-provided output buffer and forward that buffer into the backend’s output parameter instead of relying on the backend’s cached default, removing the post-call copy; ensure you thread the same caller-provided output variable through the wrapper signature and pass it into the backend call (preserve existing args like initial_state_source, initial_state_indices, output_state_indices, use_qk_l2norm_in_kernel, scale) so the backend writes directly into the caller’s buffer.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2474-2536: The code stores reusable tensors in
_compiled_kernels_ilp[cache_key] (keys "output" and "default_indices") and then
returns those directly when output or initial_state_indices/output_state_indices
are None, which exposes a mutable cached buffer; change the assignments so
callers get fresh tensors: when setting output = cache["output"] replace with a
safe copy (e.g., output = cache["output"].clone().detach() or allocate a new
empty tensor and copy_into it) and likewise set
initial_state_indices/output_state_indices = cache["default_indices"].clone()
(or .to(...).clone() if dtype/device adjustments are needed); update the code
paths around cache_key, cache["output"], cache["default_indices"],
initial_state_indices, output_state_indices, and output to return clones instead
of the cached objects.
- Around line 2419-2421: Normalize initial_state_indices to torch.int32 the same
way output_state_indices is normalized: when initial_state_indices is not None
and its dtype is not torch.int32, call .to(torch.int32) to convert it before use
(look for the variable initial_state_indices near the existing
output_state_indices conversion and the use_pool logic). Apply the same change
in both occurrences (around the blocks referencing initial_state_indices at the
current location and again near lines ~2639-2640) so the CuTe kernel always
receives int32 tensors.
- Around line 2445-2448: The reshaped pooled state h0_source (created from
initial_state_source.reshape(pool_size * HV, V, K)) can produce a non-contiguous
copy so kernel updates may not reach the caller; before passing h0_source into
from_dlpack()/the kernel ensure it is contiguous by calling .contiguous() when
needed (mirror the intermediate_states pattern), i.e., after computing h0_source
replace or assign it with a contiguous tensor if h0_source.is_contiguous() is
False so the dispatched kernel updates the original pooled buffer; apply the
same contiguity check/fix in gated_delta_rule_mtp() where the identical reshape
occurs.
---
Outside diff comments:
In `@flashinfer/gdn_decode.py`:
- Around line 211-238: The code fails to validate/normalize
initial_state_indices before dispatching to the BF16 T=1 backend: ensure that
when use_pool is True you validate initial_state_indices (shape == (B,), dtype
int32 or int64, device matches initial_state/device, and is contiguous) and then
normalize/convert it to the kernel-expected type/device (e.g., move to the same
device as initial_state and convert to torch.int32 or torch.int64 as required)
before passing it into the BF16 backend; update the logic around
initial_state_indices and the BF16 dispatch (the variables
initial_state_indices, use_pool and the BF16 T=1 backend call) so a malformed,
CPU, or wrong-dtype index tensor cannot reach the kernel.
- Around line 264-279: The wrapper is calling _gated_delta_rule_bf16_state but
still writes into the backend’s cached default buffer and copies out, causing a
race and extra copy; update the call sites in gdn_decode.py (the current
_gated_delta_rule_bf16_state invocation and the similar invocation later) to
accept a caller-provided output buffer and forward that buffer into the
backend’s output parameter instead of relying on the backend’s cached default,
removing the post-call copy; ensure you thread the same caller-provided output
variable through the wrapper signature and pass it into the backend call
(preserve existing args like initial_state_source, initial_state_indices,
output_state_indices, use_qk_l2norm_in_kernel, scale) so the backend writes
directly into the caller’s buffer.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 94e11a71-62c6-424c-a912-8bce379e813f
📥 Commits
Reviewing files that changed from the base of the PR and between 8559397 and 210eb20b0707e1c009e1c2e0362be8cc2c77da73.
📒 Files selected for processing (2)
flashinfer/gdn_decode.py
flashinfer/gdn_kernels/gdn_decode_bf16_state.py
use_pool = initial_state_indices is not None
if output_state_indices is not None and output_state_indices.dtype != torch.int32:
    output_state_indices = output_state_indices.to(torch.int32)
Normalize initial_state_indices to int32 like output_state_indices.
Both functions accept initial_state_indices documented as int32 in their docstrings, but only normalize output_state_indices when passed with a different dtype. If callers provide int64 tensors for initial_state_indices, they reach the CuTe kernel with a mismatched dtype. This should be handled consistently:
Suggested fix
  use_pool = initial_state_indices is not None
+ if initial_state_indices is not None and initial_state_indices.dtype != torch.int32:
+     initial_state_indices = initial_state_indices.to(torch.int32)
  if output_state_indices is not None and output_state_indices.dtype != torch.int32:
      output_state_indices = output_state_indices.to(torch.int32)

Also applies to: 2639-2640
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2419 - 2421,
Normalize initial_state_indices to torch.int32 the same way output_state_indices
is normalized: when initial_state_indices is not None and its dtype is not
torch.int32, call .to(torch.int32) to convert it before use (look for the
variable initial_state_indices near the existing output_state_indices conversion
and the use_pool logic). Apply the same change in both occurrences (around the
blocks referencing initial_state_indices at the current location and again near
lines ~2639-2640) so the CuTe kernel always receives int32 tensors.
# Reshape state: no-pool [B, HV, V, K] -> [B*HV, V, K];
# pool [pool_size, HV, V, K] -> [pool_size*HV, V, K].
pool_size = initial_state_source.shape[0] if use_pool else B
h0_source = initial_state_source.reshape(pool_size * HV, V, K)
Add contiguity check for reshaped pooled state before kernel dispatch.
The reshape(pool_size * HV, V, K) operation can silently allocate a copy when initial_state_source has non-contiguous strides across the pool_size and HV dimensions, even though the public API enforces only K-contiguity (stride[-1] == 1). The kernel then updates the copy and the caller's pool remains unchanged, violating the documented in-place update contract.
Adopt the pattern already applied to intermediate_states (lines 2657–2658): explicitly check and make contiguous before passing to from_dlpack().
# Reshape state: no-pool [B, HV, V, K] -> [B*HV, V, K];
# pool [pool_size, HV, V, K] -> [pool_size*HV, V, K].
pool_size = initial_state_source.shape[0] if use_pool else B
h0_source = initial_state_source.reshape(pool_size * HV, V, K)
+ if not h0_source.is_contiguous():
+     h0_source = h0_source.contiguous()

Apply the same fix at lines 2642–2643 in gated_delta_rule_mtp().
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Reshape state: no-pool [B, HV, V, K] -> [B*HV, V, K];
# pool [pool_size, HV, V, K] -> [pool_size*HV, V, K].
pool_size = initial_state_source.shape[0] if use_pool else B
h0_source = initial_state_source.reshape(pool_size * HV, V, K)
if not h0_source.is_contiguous():
    h0_source = h0_source.contiguous()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2445 - 2448,
The reshaped pooled state h0_source (created from
initial_state_source.reshape(pool_size * HV, V, K)) can produce a non-contiguous
copy so kernel updates may not reach the caller; before passing h0_source into
from_dlpack()/the kernel ensure it is contiguous by calling .contiguous() when
needed (mirror the intermediate_states pattern), i.e., after computing h0_source
replace or assign it with a contiguous tensor if h0_source.is_contiguous() is
False so the dispatched kernel updates the original pooled buffer; apply the
same contiguity check/fix in gated_delta_rule_mtp() where the identical reshape
occurs.
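For intuition, the view-vs-copy behavior hinges on stride compatibility. A minimal, framework-free sketch (helper names are illustrative, not from the PR) of the conservative condition: a full C-order reshape is guaranteed to stay a view only when the source is already C-contiguous, which is exactly what a pool with padded slot strides violates.

```python
def c_contiguous_strides(shape):
    """Element strides of a C-contiguous tensor with this shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def reshape_may_copy(shape, strides):
    """Conservative check: True when a reshape is not guaranteed to be a
    view, i.e. the source strides differ from plain C-contiguous strides."""
    return strides != c_contiguous_strides(shape)

# Pool shaped [pool_size=4, HV=64, V=128, K=128], densely packed: safe.
dense = c_contiguous_strides((4, 64, 128, 128))
# Same logical shape, but pool rows padded to twice the slot size:
# reshape(pool_size * HV, V, K) would silently materialize a copy.
padded = [2 * dense[0]] + dense[1:]
```

This is why checking `is_contiguous()` before `from_dlpack()` matters: it distinguishes the dense case from the padded one at the wrapper level.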
default_output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
default_indices = torch.arange(B, dtype=torch.int32, device=q.device)

q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True)
k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True)
v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True)
a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True)
b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True)
A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True)
dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True)
o_ = from_dlpack(default_output, assumed_align=32, enable_tvm_ffi=True)
h0_idx_ = from_dlpack(default_indices, assumed_align=32, enable_tvm_ffi=True)
h0_out_idx_ = from_dlpack(
    default_indices, assumed_align=32, enable_tvm_ffi=True
)

# Use maxrregcount=64 for smaller tile_v to improve occupancy
# when grid size is small (fewer waves)
if tile_v < 128:
    compile_opts = "--enable-tvm-ffi --generate-line-info --opt-level 3 --ptxas-options=-maxrregcount=64"
else:
    compile_opts = "--enable-tvm-ffi --generate-line-info --opt-level 3"
-_compiled_kernels_ilp[cache_key] = cute.compile(
-    run_gdn_decode_bf16state_ilp,
-    h_, A_log_, a_, dt_bias_, q_, k_, v_, b_, o_,
-    softplus_beta, softplus_threshold, scale,
-    HV, B, H, K, V,
-    use_qk_l2norm_in_kernel, use_packed_fma, tile_v, stream,
-    options=compile_opts,
-)
+_compiled_kernels_ilp[cache_key] = {
+    "compiled": cute.compile(
+        run_gdn_decode_bf16state_ilp,
+        h_, A_log_, a_, dt_bias_, q_, k_, v_, b_, o_,
+        h0_idx_, h0_out_idx_,
+        softplus_beta, softplus_threshold, scale,
+        HV, B, H, K, V,
+        use_qk_l2norm_in_kernel, use_packed_fma, tile_v, stream,
+        options=compile_opts,
+    ),
+    "output": default_output,
+    "default_indices": default_indices,
+}

cache = _compiled_kernels_ilp[cache_key]

if initial_state_indices is None:
    initial_state_indices = cache["default_indices"]
if output_state_indices is None:
    output_state_indices = initial_state_indices
if output is None:
    output = cache["output"]
Avoid returning a cached mutable output buffer.
When output is None, these functions return cache["output"]. A later same-shape call overwrites the previous result tensor, so callers that keep the returned output can observe it mutate unexpectedly.
Safer output handling
- _compiled_kernels_ilp[cache_key] = {
+ _compiled_kernels_ilp[cache_key] = {
      "compiled": cute.compile(
@@
      ),
-     "output": default_output,
      "default_indices": default_indices,
  }
@@
  if output is None:
-     output = cache["output"]
+     output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)

- _compiled_kernels_mtp[cache_key] = {
+ _compiled_kernels_mtp[cache_key] = {
      "compiled": cute.compile(
@@
      ),
      "default_indices": default_indices,
-     "output": default_output,
  }
@@
  if output is None:
-     output = cache["output"]
+     output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)

Also applies to: 2693-2754
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2474 - 2536,
The code stores reusable tensors in _compiled_kernels_ilp[cache_key] (keys
"output" and "default_indices") and then returns those directly when output or
initial_state_indices/output_state_indices are None, which exposes a mutable
cached buffer; change the assignments so callers get fresh tensors: when setting
output = cache["output"] replace with a safe copy (e.g., output =
cache["output"].clone().detach() or allocate a new empty tensor and copy_into
it) and likewise set initial_state_indices/output_state_indices =
cache["default_indices"].clone() (or .to(...).clone() if dtype/device
adjustments are needed); update the code paths around cache_key,
cache["output"], cache["default_indices"], initial_state_indices,
output_state_indices, and output to return clones instead of the cached objects.
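The hazard the bot describes is ordinary mutable-buffer aliasing. A minimal pure-Python repro (illustrative names, not the PR's code) of why returning a cached buffer breaks callers that hold on to a previous result:

```python
_cache = {}

def run_cached(key, values):
    # Reuses one buffer per cache key -- mirrors returning cache["output"].
    buf = _cache.setdefault(key, [0] * len(values))
    for i, v in enumerate(values):
        buf[i] = v * 2          # the "kernel" writes its result in place
    return buf                  # every caller gets the same object

def run_fresh(values):
    # Fresh allocation per call -- what callers of a functional API expect.
    return [v * 2 for v in values]

first = run_cached("k", [1, 2, 3])
second = run_cached("k", [7, 8, 9])
# first is second, and first now reads [14, 16, 18]:
# the earlier result was silently overwritten.
```

Allocating a fresh output when the caller passes `output=None` (as the suggestion does) trades a small allocation cost for the standard functional-API guarantee.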
perf(gdn): fix bf16_state T=1 per-call overhead and add pool+padding to ILP kernel
Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and
the T=1 dispatch in flashinfer/gdn_decode.py:
1) Per-call Python overhead fix
Move torch.arange / torch.empty / from_dlpack / get_device_capability out
of the steady-state call path. These were previously called on every
invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
of CUPTI-visible overhead per call at small BS. All default-tensor
allocation and dlpack conversion is now done once, inside the
`cache_key not in _compiled_kernels_*` block, and cached alongside the
compiled kernel. Steady-state calls pass raw torch tensors directly to
the tvm-ffi callable. Adds module-level _USE_PACKED_FMA in place of
per-call torch.cuda.get_device_capability().
2) Pool + padding support on the ILP kernel
gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
accepts h0_indices and h0_out_indices, matching the MTP kernel's
signature. Negative indices redirect to pool slot 0 (null buffer);
writes go to a separate flat_write_idx so input and output pool slots
can differ. The ILP launcher and gated_delta_rule wrapper thread the
new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
collapsed so pool+indices T=1 calls no longer detour through the
heavier MTP kernel.
Design choice: kernel always takes indices (no constexpr switch).
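The index semantics described above reduce to plain offset arithmetic. A hedged sketch (hypothetical helper, not the kernel's CuTe code): a negative read index redirects to pool slot 0 (the null buffer), and the write offset is resolved independently so input and output slots may differ.

```python
def resolve_pool_offsets(h0_idx, h0_out_idx, hv, HV, V, K):
    """Flat element offsets into a [pool_size*HV, V, K] state buffer."""
    read_slot = 0 if h0_idx < 0 else h0_idx        # negative -> null slot 0
    write_slot = 0 if h0_out_idx < 0 else h0_out_idx
    tile = V * K                                   # elements per (slot, head) tile
    flat_read_idx = (read_slot * HV + hv) * tile
    flat_write_idx = (write_slot * HV + hv) * tile
    return flat_read_idx, flat_write_idx
```

When `h0_indices` and `h0_out_indices` are the same default `arange(B)` tensor, read and write offsets coincide and the kernel behaves exactly like the pre-pool version.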
Benchmark config: Qwen3.5-397B-A17B linear attention
(num_q_heads=16, num_k_heads=16, num_v_heads=64, head_size=128, bf16, qk_l2norm ON)
GPU: NVIDIA B200
Command:
python benchmarks/bench_gdn_decode.py \
--batch-size 1 4 8 16 32 64 128 256 512 \
--num-q-heads 16 --num-k-heads 16 --num-v-heads 64 \
--head-size 128 --dtype bfloat16 --warmup 20 --iters 200
Bf16State column results (us):
BS | time
1 | 3.71
4 | 5.89
8 | 9.18
16 | 14.98
32 | 26.56
64 | 48.24
128 | 89.66
256 | 172.35
512 | 337.60
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Force-pushed from 210eb20 to 31288e5
/bot run

Mirroring failed: failed to mirror PR to GitLab. Check logs for details.

/bot run

/bot stop

The GitLab CI pipeline #49010651 has been cancelled.

/bot run

/bot run
kahyunnam left a comment:

Mostly LGTM, but left a few comments
@@ -2294,6 +2335,11 @@
# Number of SMs on target GPU (detected dynamically)
NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count
I realize this is not part of the PR; but its kind of related to the changes below. This also probably should be the current device, not index 0; we can re-use this utils function which also caches: https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/cute_dsl/utils.py#L95-L98
_GPU_MAJOR, _ = torch.cuda.get_device_capability(0)
_USE_PACKED_FMA = _GPU_MAJOR >= 10
@ameynaik-hub can you take a look here? This might also have to be torch.cuda.current_device().
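One hedged way to address both device-index comments at once is to cache the capability check per device index rather than hard-coding 0. This is a sketch, not the PR's code; `get_cap` stands in for `torch.cuda.get_device_capability`, and the real call site would query with `torch.cuda.current_device()`.

```python
import functools

def make_packed_fma_check(get_cap):
    """Build a per-device-index, memoized packed-FMA predicate."""
    @functools.lru_cache(maxsize=None)
    def use_packed_fma(device_index):
        major, _minor = get_cap(device_index)
        return major >= 10          # packed FMA path assumed to need SM 10.x+
    return use_packed_fma
```

This keeps the steady-state path allocation-free (a dict lookup inside `lru_cache`) while staying correct on multi-GPU hosts where different devices may have different compute capabilities.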
if output is None:
    output = cache["output"]
@ameynaik-hub can you take a look at the AI review bot comments here?
# PUBLIC API
# ==============================================================================
_compiled_kernels: dict = {}
_compiled_kernels_ilp: dict = {}
This is not a blocker for this PR, but I'm not 100% following what is going on here; manually managing a cache dictionary seems to do the same thing as @functools.cache.
Is there a reason we're not following the @functools.cache pattern like the other cute dsl backends, like here? https://github.com/flashinfer-ai/flashinfer/blob/main/flashinfer/gdn_kernels/gdn_decode_pretranspose.py#L886-L906
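For reference, the `@functools.cache` pattern the reviewer links to looks roughly like this (illustrative stand-in; the real function would return a `cute.compile` result keyed on the shape/config tuple):

```python
import functools

@functools.cache
def get_compiled_kernel(B, HV, K, V, tile_v, use_qk_l2norm):
    # Expensive one-time work; memoized automatically on the argument
    # tuple, replacing the hand-managed _compiled_kernels_ilp dict.
    return ("compiled", B, HV, K, V, tile_v, use_qk_l2norm)
```

One reason the manual dict may still be preferable here: this PR caches mutable companion tensors (`default_indices`, `default_output`) alongside the compiled callable, which `@functools.cache` does not express as directly.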
## Summary

Replaces the legacy `gdn_decode_bf16state_cooprow_kernel` and the `gdn_decode_bf16state_mtp_kernel` (ILP=8) with a new **`gdn_wide_vec_kernel`** (LDG.E.128 / STG.E.128 fast path) plus a small-batch `mtp_ilp4` fallback. Drops ~1900 LOC of dead/unused code, adds split-pool support (#2905-compatible) to both surviving BF16 kernels, and ships the OOB fix mirroring upstream PR #3145 for the BF16 kernels that survived the cleanup.

**Supersedes #3118.** That PR's perf delta (T=1 per-call overhead + pool+padding for the ILP kernel) is the first commit on this branch (`8a6e9819`).

## What changes

- **New kernel**: `gdn_wide_vec_kernel` — 128 threads/CTA = 8 groups × 16 threads, vec=8 BF16 → LDG.E.128 / STG.E.128, ILP=4 V-rows per thread. Configurable `tile_v ∈ {32, 64, 128}` so the kernel covers small/medium/large `B*HV` work-unit sizes uniformly.
- **Pool-only**: BF16 GDN dispatch is strictly pool-mode (matches the production serving contract). Wrapper `gated_delta_rule_decode_pretranspose` auto-promotes legacy non-pool callers internally — public API unchanged.
- **Split-pool support** (PR #2905 contract): both surviving BF16 kernels (`gdn_wide_vec_kernel`, `gdn_decode_bf16state_mtp_ilp4_kernel`) natively support `output_state_indices != initial_state_indices`, with bit-equivalent single-pool behavior selected at compile time via `Constexpr[bool] same_pool` for zero-overhead dispatch.
- **OOB fix (PR #3145 equivalent)**: `intermediate_states` is indexed by the per-call batch index `i_n` (not the pool-scoped `cache_idx`), so the buffer can be sized `[B, T, HV, V, K]` as production callers expect. Regression test catches the bug; pre-fix triggers `cudaErrorIllegalAddress` in <2 s.

## Removed (~1900 LOC of dead code)

| Kernel | Why removed |
|---|---|
| `gdn_decode_bf16state_cooprow_kernel` (~280 LOC) | Replaced by wide_vec + ILP=4 MTP; had known correctness issues at small batch |
| `gdn_decode_bf16state_ilp_kernel` (~740 LOC) | Only reachable at HV<32 with B≥16 — not a Qwen3.5 shape; MTP path covers it |
| `gdn_decode_bf16state_mtp_kernel` (ILP=8) (~940 LOC) | After wide_vec extension to split-pool + tile_v=32, mtp_kernel was unreachable |

End-state BF16 surface = **2 `@cute.kernel`s in one file**:

- `gdn_wide_vec_kernel` — production hot path
- `gdn_decode_bf16state_mtp_ilp4_kernel` — small-batch fallback

Both pool-only, both split-pool capable, both indexed batch-scoped.

## Speedup vs previous baseline

Baseline = pre-wide_vec dispatch (the `mtp_kernel` ILP=8 path, captured on this same branch by monkey-patching `_select_wide_vec_tile_v` to return `None` for every shape). Same harness, same hardware, same config — so the comparison isolates the kernel-level speedup that wide_vec + the cleanup deliver.

Setup: B200, HV=64, K=V=128, BF16, qk_l2norm=ON, warmup=5, iters=50, T=1 invoked with `--update-state`, T≥2 invoked with `--cache-intermediate-states`. Kernel time in microseconds (CUPTI).

### Speedup (×, baseline / post-PR)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|-------|-------|-------|-------|-------|-------|-------|
| 1 | 1.03× | 1.04× | 1.03× | 1.03× | 1.00× | 1.02× | 1.02× | 1.01× |
| 4 | 0.97× | 1.23× | 1.10× | 1.12× | 1.11× | 1.11× | 1.14× | 1.14× |
| 8 | 1.08× | 1.11× | 1.11× | 1.12× | 1.13× | 1.15× | 1.15× | 1.14× |
| 16 | 1.04× | 1.09× | 1.11× | 1.13× | 1.13× | 1.12× | 1.11× | 1.10× |
| 32 | 1.06× | 1.12× | 1.10× | 1.11× | 1.10× | 1.09× | 1.09× | 1.09× |
| 64 | 1.04× | 1.11× | 1.08× | 1.09× | 1.06× | 1.07× | 1.08× | 1.06× |
| 128 | 1.04× | 1.11× | 1.07× | 1.09× | 1.06× | 1.06× | 1.07× | 1.07× |
| 256 | 1.04× | 1.11× | 1.07× | 1.09× | 1.07× | 1.06× | 1.07× | 1.07× |

### Time reduction (%)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|--------|-------|--------|--------|--------|--------|--------|
| 1 | +2.8% | +3.5% | +3.2% | +3.3% | +0.1% | +1.7% | +1.8% | +0.7% |
| 4 | −3.2% | +18.7% | +9.5% | +10.7% | +9.7% | +10.2% | +12.3% | +12.5% |
| 8 | +7.8% | +10.1% | +10.0% | +10.7% | +11.8% | +12.7% | +12.8% | +12.4% |
| 16 | +3.4% | +8.5% | +10.0% | +11.2% | +11.4% | +10.7% | +10.3% | +9.4% |
| 32 | +6.0% | +10.4% | +9.4% | +9.6% | +8.8% | +8.4% | +8.4% | +7.8% |
| 64 | +4.0% | +10.3% | +7.5% | +8.6% | +5.9% | +6.3% | +7.2% | +5.9% |
| 128 | +4.2% | +9.8% | +6.5% | +8.4% | +6.0% | +5.9% | +6.3% | +6.6% |
| 256 | +4.2% | +9.6% | +6.6% | +8.4% | +6.3% | +5.9% | +6.5% | +6.6% |

### Headline

- **T=1 production decode (B≥16)**: 4–6 % time reduction across the full batch sweep — the Qwen3.5 hot path.
- **T≥2 with cache=ON (B≥4)**: 6–18 % time reduction at every shape. Best at small-T / mid-batch (B=4 T=2: 1.23×; B=8 T=6: 1.15×).
- **Tiny shapes (B=1)**: within ±3 % of baseline (kernel isn't DRAM-bound; small fixed-cost overheads dominate; the ILP=4 fallback was already efficient there).

### Sustained DRAM bandwidth post-PR (TB/s, 8 TB/s peak on B200)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|------|------|------|------|------|------|------|------|
| 1 | 1.25 | 1.21 | 1.49 | 1.63 | 1.45 | 1.55 | 1.64 | 1.70 |
| 4 | 2.83 | 3.24 | 3.54 | 3.78 | 3.53 | 3.68 | 3.84 | 3.95 |
| 8 | 3.97 | 4.09 | 4.40 | 4.54 | 4.42 | 4.55 | 4.59 | 4.61 |
| 16 | 4.73 | 4.73 | 5.02 | 5.03 | 4.95 | 4.91 | 4.92 | 4.87 |
| 32 | 5.39 | 5.36 | 5.44 | 5.46 | 5.27 | 5.23 | 5.21 | 5.17 |
| 64 | 5.83 | 5.76 | 5.80 | 5.77 | 5.45 | 5.44 | 5.45 | 5.33 |
| 128 | 6.31 | 6.05 | 6.03 | 6.01 | 5.68 | 5.61 | 5.57 | 5.54 |
| 256 | 6.57 | 6.23 | 6.20 | 6.17 | 5.85 | 5.74 | 5.72 | 5.66 |

Post-PR peaks at **6.57 TB/s = 82 % of B200 peak DRAM** (T=1 B=256 production decode shape).

### Split-pool

With wide_vec now supporting split-pool natively, split-pool matches single-pool to within ±1 % at every measured shape.

## Tests

> **513 passed, 0 failed in 18m18s** on B200.

Including: 477 existing BF16/wide_vec/pool tests, 12 new split-pool MTP tests, 12 new OOB regression tests covering `pool_size_multiplier ∈ {1, 4}` × `B ∈ {1, 8, 32}` × `T ∈ {2, 4}`, 12 wrapper-level split-pool tests.

## Files changed (4)

- `flashinfer/gdn_decode.py` — wrapper auto-promotes BF16 non-pool → pool
- `flashinfer/gdn_kernels/gdn_decode_bf16_state.py` — wide_vec inlined; dead kernels removed; split-pool plumbing; OOB fix; same_pool DCE
- `tests/gdn/test_decode_delta_rule.py` — split-pool + OOB regression tests
- `benchmarks/bench_gdn_decode.py` — `--pool-mode {single,split}` flag

## Summary by CodeRabbit

- **New Features**
  - Added `--pool-mode` option to benchmark tool for configuring state pool allocation (`single` or `split` modes).
- **Tests**
  - Expanded BF16 test coverage with regression tests for split-pool semantics and out-of-bounds scenarios; improved batch-dimension handling for intermediate-state comparisons.

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
perf(gdn): fix bf16_state T=1 per-call overhead and add pool+padding to ILP kernel
Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and the T=1 dispatch in flashinfer/gdn_decode.py:
Per-call Python overhead fix
Move torch.arange / torch.empty / from_dlpack / get_device_capability out
of the steady-state call path. These were previously called on every
invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
of CUPTI-visible overhead per call at small BS. All default-tensor
allocation and dlpack conversion is now done once, inside the
`cache_key not in _compiled_kernels_*` block, and cached alongside the compiled kernel. Steady-state calls pass raw torch tensors directly to
the tvm-ffi callable. Adds module-level _USE_PACKED_FMA in place of
per-call torch.cuda.get_device_capability().
Pool + padding support on the ILP kernel
gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
accepts h0_indices and h0_out_indices, matching the MTP kernel's
signature. Negative indices redirect to pool slot 0 (null buffer);
writes go to a separate flat_write_idx so input and output pool slots
can differ. The ILP launcher and gated_delta_rule wrapper thread the
new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
collapsed so pool+indices T=1 calls no longer detour through the
heavier MTP kernel.
Design choice: kernel always takes indices (no constexpr switch).
Benchmark: bench_gdn_decode.py (B200, HV=64, T=1, bf16, qk_l2norm ON)
Command:
python benchmarks/bench_gdn_decode.py \
    --batch-size 1 4 8 16 32 64 128 256 512 \
    --num-q-heads 32 --num-k-heads 32 --num-v-heads 64 \
    --head-size 128 --dtype bfloat16 --warmup 20 --iters 200
Bf16State column results (us):
BS | time
1 | 3.68
4 | 5.82
8 | 9.30
16 | 15.01
32 | 26.86
64 | 48.22
128 | 89.63
256 | 172.74
512 | 337.89
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Performance