feat(kda): add recurrent KDA decode kernel with per-K gating#2572
djmmoss wants to merge 81 commits into flashinfer-ai:main
Conversation
Uses GitHub's mergeUpstream API to keep main branch in sync with flashinfer-ai/flashinfer. Runs every hour and can be triggered manually.
Add KDA (Key-Driven Attention) decode support as a CuTe DSL kernel, extending the GDN decode kernel from PR flashinfer-ai#2498 to support per-key-dimension gating. KDA generalizes GDN's scalar gate (g in R^1) to per-K gating (g in R^K), with the gate mapping naturally to the warp structure.

Changes:
- Extract shared gate-independent helpers from the GDN kernel into flashinfer/gdn_kernels/_common.py (~290 lines), slimming gdn_decode_bf16_state.py. No GDN behavior change.
- Add HEAD_DIM=64 support to GDN dispatch (previously 128 only)
- Preserve lowBS_1chunk kernel variants for B<=4 (both GDN and KDA)
- New flashinfer/kda_kernels/ module with T=1-4 kernels for HEAD_DIM={64,128}, plus a chunk_kda-compatible wrapper
- 80 KDA tests covering correctness, state updates, and the GDN reduction
- KDA decode benchmark

Tested on B200 (SM100) with CUDA 12.9. BF16 storage, FP32 compute. GDN: 138/138 tests pass, no performance regression. KDA: 80/80 tests pass.

AI-assisted (Claude Code)
- KDA DLPack cache key: (T, B, HEAD_DIM) -> (T, B, H, HV, HEAD_DIM) to avoid incorrect kernel reuse when H or HV differs - Update stale cache dict comments in both GDN and KDA modules
- bench_kda_decode: fix dtype.itemsize crash (use tensor.element_size())
- bench_kda_decode: prefix unused num_k_heads with underscore
- kda_decode_bf16_state: use importlib.util.find_spec for tvm_ffi check
- kda_decode_bf16_state: rename unused o_head param to _o_head
- kda_decode_bf16_state: assert use_qk_l2norm_in_kernel=True (always on)
- gdn_decode: update docstring K=V=128 -> K=V in {64,128}
- test_decode_kda: fix SM arch check to lower-bound (cc[0] < 9)
- test_decode_kda: rename test_kda_reduces_to_gdn (misleading name)
g, beta, output, state use HV (num_v_heads), not max(q, v). q/k use H (num_q_heads). Fixes skewed TB/s metric for GQA configs.
Reduce ATOL/RTOL from 1e-1/5e-2 to 5e-3/5e-3, matching the GDN decode test tolerances. Measured worst-case errors are ~1.5e-5 absolute and ~7.8e-3 relative, so 5e-3 still provides ample margin.
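For reference, a sketch of how these tolerances behave: `torch.testing.assert_close` passes elementwise when |out − ref| ≤ ATOL + RTOL·|ref|, so a ~7.8e-3 relative error on a small-magnitude element can still pass via the ATOL term (the example tensors below are illustrative, not from the test suite):

```python
import torch

ATOL, RTOL = 5e-3, 5e-3  # tightened from 1e-1 / 5e-2, matching the GDN decode tests

# assert_close passes elementwise when |out - ref| <= ATOL + RTOL * |ref|,
# so a ~0.78% relative error on a tiny element still passes via the ATOL term.
ref = torch.tensor([1.0, 1e-3])
out = ref + torch.tensor([2e-3, 7.8e-6])  # ~0.2% and ~0.78% relative error
torch.testing.assert_close(out, ref, atol=ATOL, rtol=RTOL)
```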
…GDN changes

Replace the KDA decode kernel with the new recurrent_kda implementation that uses TVM FFI compilation and supports cu_seqlens, ssm_state_indices, and in-kernel gate computation (softplus/lower-bound modes). The kernel is now T=1 only. Revert all GDN refactoring that was part of the original PR, since it is no longer needed (KDA is fully self-contained).
Actionable comments posted: 5
♻️ Duplicate comments (3)
flashinfer/kda_kernels/kda_decode_bf16_state.py (1)
1079-1080: ⚠️ Potential issue | 🟠 Major

use_qk_l2norm_in_kernel is still silently ignored. The flag is exposed in the API/docs but never changes behavior. Please explicitly reject False until a non-normalizing path exists.

🔧 Minimal safe guard

```diff
 def recurrent_kda(
 @@
 ) -> torch.Tensor:
 @@
+    if not use_qk_l2norm_in_kernel:
+        raise NotImplementedError(
+            "recurrent_kda currently requires use_qk_l2norm_in_kernel=True."
+        )
```

Also applies to: 1132-1135
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/kda_kernels/kda_decode_bf16_state.py` around lines 1079-1080: the parameter use_qk_l2norm_in_kernel is declared but ignored; add an explicit runtime guard that rejects False (e.g., raise ValueError) wherever the flag is accepted so callers cannot silently disable L2-normalization until a non-normalizing kernel path is implemented. Update the check in the constructor/function that defines use_qk_l2norm_in_kernel (and mirror the same guard at the other occurrence around lines 1132-1135) so any attempt to pass use_qk_l2norm_in_kernel=False fails fast with a clear error message referencing use_qk_l2norm_in_kernel and the containing function/class name.

benchmarks/bench_kda_decode.py (1)
74-99: ⚠️ Potential issue | 🟠 Major

Use the query-head count for k_bytes in this benchmark path. Line 146 builds k with num_q_heads, but line 98 models k_bytes with num_k_heads. That skews reported TB/s whenever the two differ.

🔧 Proposed fix

```diff
 def kda_decode_bytes(
     batch_size: int,
     num_q_heads: int,
-    num_k_heads: int,
+    _num_k_heads: int,
 @@
-    k_bytes = batch_size * seq_len * num_k_heads * head_size * elem_size
+    k_bytes = batch_size * seq_len * num_q_heads * head_size * elem_size
```

Based on learnings: in flashinfer/kda_kernels/kda_decode_bf16_state.py, k is indexed with query_head_idx and expected as [B, T, num_q_heads, K] (pre-expanded to query heads).

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_kda_decode.py` around lines 74-99: in kda_decode_bytes the modeled k_bytes uses num_k_heads, but the actual kernel and construction of k use query heads (num_q_heads); change the k_bytes calculation to use num_q_heads instead of num_k_heads, so k_bytes = batch_size * seq_len * num_q_heads * head_size * elem_size (update the expression referenced as k_bytes inside kda_decode_bytes).

tests/kda/test_decode_kda.py (1)
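To make the byte-accounting concrete, here is a hedged sketch of a corrected model (the helper name is hypothetical, and the g/beta/state terms are assumptions based on the shapes discussed in this PR: q/k use query heads, g/output/state use value heads, bf16 elem_size=2, state read and written once):

```python
def kda_decode_bytes_sketch(batch_size, num_q_heads, num_v_heads, head_size,
                            seq_len=1, elem_size=2):
    """Rough bytes-moved model for one recurrent KDA decode step (sketch only)."""
    # q and k both use num_q_heads: k is pre-expanded to query heads
    qk_bytes = 2 * batch_size * seq_len * num_q_heads * head_size * elem_size
    # v, per-K gate g, and output use value heads
    v_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size
    g_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size
    beta_bytes = batch_size * seq_len * num_v_heads * elem_size
    o_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size
    # [K, V] state per value head, read once and written once
    state_bytes = 2 * batch_size * num_v_heads * head_size * head_size * elem_size
    return qk_bytes + v_bytes + g_bytes + beta_bytes + o_bytes + state_bytes
```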
4-9: ⚠️ Potential issue | 🟠 Major

Add the required architecture skip helper from flashinfer.utils. These tests launch the CuTe KDA kernel but do not gate execution on supported GPU architectures. Please add a flashinfer.utils-based skip guard and call it at the start of each test.

🔧 Proposed fix

```diff
 import pytest
 import torch
 import torch.nn.functional as F

 from flashinfer.kda_kernels import cutedsl_kda_decode
+from flashinfer.utils import get_compute_capability
+
+
+def _skip_if_not_sm100_or_later():
+    cc = get_compute_capability(torch.device("cuda"))
+    if cc[0] < 10:
+        pytest.skip(f"KDA decode requires SM100+, but got SM{cc[0]}{cc[1]}")
 @@
 def test_cutedsl_vs_naive(
 @@
 ):
+    _skip_if_not_sm100_or_later()
     """CuTe DSL kernel matches naive recurrent KDA reference."""
 @@
 def test_cutedsl_vs_fla(
 @@
 ):
+    _skip_if_not_sm100_or_later()
     """CuTe DSL kernel matches fla fused_recurrent_kda."""
 @@
 def test_vllm_decode(
 @@
 ):
+    _skip_if_not_sm100_or_later()
     """vLLM-style decoding: continuous batching with paged state, CuTe DSL vs naive."""
```

As per coding guidelines, tests/**/*.py: "Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures."

Also applies to: 119-123, 179-183, 240-246
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/kda/test_decode_kda.py` around lines 4 - 9, Import the GPU-architecture helper(s) from flashinfer.utils (for example get_compute_capability and is_sm90a_supported) at the top of the test file and add a guard at the start of each test that runs the CuTe KDA kernel (those invoking cutedsl_kda_decode) which calls pytest.skip when the current GPU architecture is not supported; specifically, for each test function call the helper(s) to detect support and call pytest.skip with a short message if unsupported so the CuTe KDA tests are gated on supported architectures.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_kda_decode.py`:
- Around line 220-224: The CLI and runner are still allowing seq lengths
[1,2,3,4] while the kernel only supports T=1, causing noisy ERROR rows; update
the filtering and defaults to enforce seq_len == 1 by restricting valid_seq_lens
to only include 1 (replace the comprehension using args.seq_len with one that
only accepts 1) and update any other occurrences (e.g., the runner/CLI handling
around the second occurrence at the 311-315 region) to default to and only
iterate over seq_len == 1 so only T=1 runs are scheduled.
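The T=1 restriction above can be sketched as a small filter (hypothetical helper; valid_seq_lens and args.seq_len are the names used in the benchmark CLI):

```python
def filter_seq_lens(requested):
    """Keep only T=1 runs, since the recurrent KDA kernel supports seq_len == 1.

    Hypothetical sketch of the filtering the review asks for; the real
    benchmark builds valid_seq_lens from args.seq_len.
    """
    valid_seq_lens = [s for s in requested if s == 1]
    if not valid_seq_lens:
        raise ValueError("recurrent KDA decode only supports seq_len == 1")
    return valid_seq_lens
```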
In `@flashinfer/kda_kernels/kda_decode_bf16_state.py`:
- Line 1139: Remove the unused local H at the q.shape unpack; change the
assignment that currently reads "H, K = q.shape[2], q.shape[3]" so it only
assigns K (e.g., assign K from q.shape[3] or unpack into a throwaway for the
first value) to eliminate the unused-variable F841; locate this in
kda_decode_bf16_state where q.shape is referenced and update the assignment
accordingly.
- Around line 1170-1172: The default output allocation incorrectly uses a
hardcoded second dimension of 1 causing out-of-bounds indexing in cu_seqlens
mode; modify the allocation for output so its shape uses q.shape[1] (the
flattened token axis) when running in cu_seqlens mode (i.e., when cu_seqlens is
present/used), e.g. set the second dimension to (q.shape[1] if cu_seqlens is not
None else 1) so indexing by token_offset/gO is valid; update the allocation site
where output is created (symbols: output, q, cu_seqlens, token_offset, gO)
accordingly.
- Around line 1317-1334: The cu_seqlens branch currently unconditionally creates
out_buf = torch.zeros_like(v) and ignores the caller-provided output buffer;
change that branch to mirror the non-cu_seqlens logic: use the provided output
if not None (validating shape/device/dtype match to v) else allocate a new
buffer (e.g., torch.zeros_like(v)); ensure you still set out_buf to this buffer
and preserve existing behavior for initial_state/state/ssi, and do not change
variable names (cu_seqlens, out_buf, v, output, initial_state, state).
- Around line 1302-1306: Before dispatch, add guards validating head-ratio and
tensor shapes: ensure HV >= H and HV % H == 0 before computing query_head_idx =
value_head_idx // (HV // H), and validate that k, g, beta tensors have expected
head/feature dimensions matching K, V, HV (e.g., check k.ndim and last dims
equal K, g shape aligns with H or HV, beta matches attention heads). Also keep
existing dtype/shape asserts (q, v, K==V, K in (64,128)) but add explicit checks
for value_head_idx bounds and that (HV // H) != 0 to avoid division by zero and
out-of-range indexing in functions/methods computing query_head_idx and using
k/g/beta.
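A minimal sketch of the head-ratio guard requested above (hypothetical helper name; H = number of query heads, HV = number of value heads, matching the prompt's symbols):

```python
def check_gqa_head_ratio(H: int, HV: int) -> int:
    """Validate the GQA head ratio before dispatch and return the group size.

    Sketch only: the kernel computes query_head_idx = value_head_idx // (HV // H),
    so HV // H must be a positive integer to avoid division by zero and
    out-of-range indexing.
    """
    if H <= 0 or HV <= 0:
        raise ValueError(f"head counts must be positive, got H={H}, HV={HV}")
    if HV < H or HV % H != 0:
        raise ValueError(f"expected HV >= H and HV % H == 0, got H={H}, HV={HV}")
    return HV // H
```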
---
Duplicate comments:
In `@benchmarks/bench_kda_decode.py`:
- Around line 74-99: In kda_decode_bytes the modeled k_bytes uses num_k_heads
but the actual kernel and construction of k use query heads (num_q_heads);
change the k_bytes calculation to use num_q_heads instead of num_k_heads so
k_bytes = batch_size * seq_len * num_q_heads * head_size * elem_size (update the
expression referenced as k_bytes inside kda_decode_bytes).
In `@flashinfer/kda_kernels/kda_decode_bf16_state.py`:
- Around line 1079-1080: The parameter use_qk_l2norm_in_kernel is declared but
ignored; add an explicit runtime guard that rejects False (e.g., raise
ValueError) wherever the flag is accepted so callers cannot silently disable
L2-normalization until a non-normalizing kernel path is implemented; update the
check in the constructor/function that defines use_qk_l2norm_in_kernel (and
mirror the same guard at the other occurrence around lines 1132-1135) so any
attempt to pass use_qk_l2norm_in_kernel=False fails fast with a clear error
message referencing use_qk_l2norm_in_kernel and the containing function/class
name.
In `@tests/kda/test_decode_kda.py`:
- Around line 4-9: Import the GPU-architecture helper(s) from flashinfer.utils
(for example get_compute_capability and is_sm90a_supported) at the top of the
test file and add a guard at the start of each test that runs the CuTe KDA
kernel (those invoking cutedsl_kda_decode) which calls pytest.skip when the
current GPU architecture is not supported; specifically, for each test function
call the helper(s) to detect support and call pytest.skip with a short message
if unsupported so the CuTe KDA tests are gated on supported architectures.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- benchmarks/bench_kda_decode.py
- flashinfer/gdn_kernels/gdn_decode_bf16_state.py
- flashinfer/kda_kernels/__init__.py
- flashinfer/kda_kernels/cutedsl_kda_decode.py
- flashinfer/kda_kernels/kda_decode_bf16_state.py
- tests/gdn/test_decode_delta_rule.py
- tests/kda/test_decode_kda.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/gdn/test_decode_delta_rule.py
- flashinfer/kda_kernels/cutedsl_kda_decode.py
Rename kda_decode_bf16_state.py -> recurrent_kda.py and merge the separate _compile_tvm_ffi / recurrent_kda / cutedsl_kda_decode layers into a single cutedsl_kda_decode function with internal _dispatch_kernel and _compile_tvm_ffi helpers. Remove unused RecurrentKDAKernel class.
- Tighten test tolerances from 1e-1/5e-2 to 5e-3/5e-3 (yzh119)
- Add HV >= H and HV % H == 0 guard before dispatch (CodeRabbit)
- Respect caller-provided output buffer in cu_seqlens path (CodeRabbit)
- Restrict benchmark CLI to T=1 only (CodeRabbit)
…est_recurrent_kda
The recurrent KDA kernel requires SM100a (Blackwell). Add module-level pytest.skip using is_sm100a_supported() so CI on Hopper/Ampere skips gracefully instead of failing at CuTe DSL compilation.
Module-level pytest.skip causes 0 tests collected, which CI treats as failure. Switch to per-test _skip_if_not_sm100() calls matching the pattern used by GDN tests, so tests are collected and individually skipped on non-SM100 architectures.
Description
Add recurrent KDA (Key-Driven Attention) decode kernel as a CuTe DSL kernel
for SM100 (Blackwell), supporting per-key-dimension gating for T=1 decode.
Algorithm: KDA generalizes GDN's scalar gate (g ∈ R¹) to per-K-dimension
gating (g ∈ R^K). The gate maps naturally to the warp structure — each warp's
32 lanes share the same 32 gate values (gate varies by K, not V).
State update:
S = diag(exp(g)) @ S + beta * k * (v - S^T k)

Changes:
- flashinfer/kda_kernels/recurrent_kda.py — T=1 kernel for HEAD_DIM={64,128}
- recurrent_kda() with GQA, cu_seqlens, ssm_state_indices support
- Kernel compilation cached via @functools.cache
- Benchmark (benchmarks/bench_recurrent_kda.py) and tests (tests/kda/test_recurrent_kda.py)

Key design decisions:
- g is [B, 1, HV, K] (or raw input with use_gate_in_kernel=True)
- beta is [B, 1, HV], passed in directly
- State stored as [N, HV, V, K] bf16 (K-last) for efficient per-K gate application

Related Issues
Extends #2498 (GDN decode CuTe DSL kernel)
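For context, the state update S = diag(exp(g)) @ S + beta * k * (v - S^T k) can be sketched as a naive single-step PyTorch reference (illustrative only, not the repository's actual reference code; assumes in-kernel q/k L2-normalization and a per-head [K, V] state layout):

```python
import torch

def naive_kda_decode_step(q, k, v, g, beta, S):
    """One T=1 decode step of the per-K-gated delta rule (sketch).

    Assumed shapes: q, k, g: [H, K]; v: [H, V]; beta: [H]; S: [H, K, V].
    Implements S = diag(exp(g)) @ S + beta * k (v - S^T k)^T, then o = S^T q.
    """
    q = torch.nn.functional.normalize(q, dim=-1)  # use_qk_l2norm_in_kernel=True
    k = torch.nn.functional.normalize(k, dim=-1)
    S = torch.exp(g).unsqueeze(-1) * S                # per-K gate: diag(exp(g)) @ S
    delta = v - torch.einsum("hkv,hk->hv", S, k)      # delta-rule residual v - S^T k
    S = S + beta[:, None, None] * torch.einsum("hk,hv->hkv", k, delta)
    o = torch.einsum("hkv,hk->hv", S, q)              # readout o = S^T q
    return o, S
```

Unlike GDN, the gate term exp(g) here varies along K (rows of S), which is why the 32 lanes of a warp can share 32 gate values in the kernel.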
Pull Request Checklist
Tests
12/12 passed:
- test_recurrent_kda_vs_naive: kernel vs naive reference (B={1,4,16}, HD={64,128})
- test_recurrent_kda_vs_fla: kernel vs fla fused_recurrent_kda
- test_vllm_decode: cu_seqlens + paged state with pre-computed, softplus, and lower-bound gate modes (B={7,16,32}, HD={64,128})

Benchmarks
Recurrent KDA Performance (HD=128, H=16, HV=32, B200/SM100)