feat(kda): add recurrent KDA decode kernel with per-K gating #2572

Open
djmmoss wants to merge 81 commits into flashinfer-ai:main from djmmoss:kda-decode-cutedsl

Conversation


@djmmoss djmmoss commented Feb 17, 2026

Description

Add recurrent KDA (Key-Driven Attention) decode kernel as a CuTe DSL kernel
for SM100 (Blackwell), supporting per-key-dimension gating for T=1 decode.

Algorithm: KDA generalizes GDN's scalar gate (g ∈ R¹) to per-K-dimension
gating (g ∈ R^K). The gate maps naturally to the warp structure — each warp's
32 lanes share the same 32 gate values (gate varies by K, not V).

State update: S = diag(exp(g)) @ S + beta * k * (v - S^T k)
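The update above can be written out as a naive PyTorch reference for one decode step (an illustrative sketch of the math, not the CuTe DSL kernel; the `kda_decode_step` helper and its `[K, V]` state shape are assumptions for clarity):

```python
import torch

def kda_decode_step(S, q, k, v, g, beta):
    """One T=1 KDA decode step for a single head (naive reference sketch).
    Assumed shapes: S [K, V] state, q/k/g [K], v [V], beta scalar."""
    S = torch.exp(g)[:, None] * S                  # diag(exp(g)) @ S: per-K gate
    S = S + beta * torch.outer(k, v - S.T @ k)     # rank-1 delta-rule update
    return q @ S, S                                # output o = S^T q, new state
```

Because `g` varies only along K, each row of `S` is scaled by a single gate value, which is what lets the 32 gate values map onto a warp's 32 lanes.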

Changes:

  • New flashinfer/kda_kernels/recurrent_kda.py — T=1 kernel for HEAD_DIM={64,128}
  • Public API: recurrent_kda() with GQA, cu_seqlens, ssm_state_indices support
  • In-kernel gate modes: pre-computed log-space, softplus, lower_bound * sigmoid
  • TVM FFI compilation with @functools.cache
  • Benchmark (benchmarks/bench_recurrent_kda.py) and tests (tests/kda/test_recurrent_kda.py)
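The three gate modes listed above might map to log-space decay values along these lines (a sketch of assumed semantics only — the in-kernel formulas may differ, and `gate_to_log_space` / `lower_bound=0.9` are illustrative names and defaults):

```python
import torch
import torch.nn.functional as F

def gate_to_log_space(g_raw, mode, lower_bound=0.9):
    # Assumed semantics of the three gate modes; all yield a log-space
    # gate whose exp() is a decay factor below 1.
    if mode == "precomputed":
        return g_raw                        # caller already supplies log-space gates
    if mode == "softplus":
        return -F.softplus(g_raw)           # decay exp(-softplus(x)) in (0, 1)
    if mode == "lower_bound":
        # per the PR's "lower_bound * sigmoid": decay in (0, lower_bound)
        return torch.log(lower_bound * torch.sigmoid(g_raw))
    raise ValueError(f"unknown gate mode: {mode}")
```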

Key design decisions:

  • Pre-computed log-space gates g[B, 1, HV, K] (or raw input with use_gate_in_kernel=True)
  • Pre-sigmoided beta[B, 1, HV] passed in directly
  • State layout [N, HV, V, K] bf16 (K-last) for efficient per-K gate application
  • In-kernel L2 normalization of Q and K
  • cu_seqlens + ssm_state_indices for vLLM-style continuous batching
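For the continuous-batching path, a naive reference shows how per-request state pages might be indexed at T=1 (illustrative only — `paged_kda_decode_ref`, its shapes, and its `[K, V]`-per-head state layout are assumptions, not the FlashInfer API or the kernel's K-last layout):

```python
import torch

def paged_kda_decode_ref(q, k, v, g, beta, state_pool, state_idx):
    # Naive T=1 paged-state decode. state_idx[b] selects the state page
    # for request b, mimicking ssm_state_indices. Assumed shapes:
    #   q/k/g [B, H, K], v [B, H, V], beta [B, H], state_pool [N, H, K, V].
    B, H, K = q.shape
    o = torch.empty(B, H, v.shape[-1])
    for b in range(B):
        for h in range(H):
            S = state_pool[state_idx[b], h]
            S = torch.exp(g[b, h])[:, None] * S                       # per-K gate
            S = S + beta[b, h] * torch.outer(k[b, h], v[b, h] - S.T @ k[b, h])
            state_pool[state_idx[b], h] = S                           # write page back
            o[b, h] = q[b, h] @ S
    return o
```

At T=1 decode, cu_seqlens degenerates to `torch.arange(B + 1)`; the interesting indirection is the state-page index, which lets batches reference non-contiguous pages in a vLLM-style state pool.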

Related Issues

Extends #2498 (GDN decode CuTe DSL kernel)

Pull Request Checklist

  • Tests added and passing (12/12)
  • Benchmark working
  • No GDN changes (fully self-contained)

Tests

12/12 passed:

  • test_recurrent_kda_vs_naive: kernel vs naive reference (B={1,4,16}, HD={64,128})
  • test_recurrent_kda_vs_fla: kernel vs fla fused_recurrent_kda
  • test_vllm_decode: cu_seqlens + paged state with pre-computed, softplus, and lower-bound gate modes (B={7,16,32}, HD={64,128})

Benchmarks

Recurrent KDA Performance (HD=128, H=16, HV=32, B200/SM100)

| Batch | Time (μs) | TFLOPS | TB/s |
|------:|----------:|-------:|-----:|
|     1 |     10.24 |   0.41 | 0.21 |
|    16 |     14.34 |   4.68 | 2.38 |
|    64 |     34.85 |   7.70 | 3.91 |
|   128 |     59.39 |   9.04 | 4.59 |
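A back-of-envelope check (assuming the bf16 state is read once and written once per decode step) suggests state traffic dominates the memory movement at large batch:

```python
def kda_state_bytes(batch, num_v_heads, head_dim, elem_size=2):
    # bf16 state [B, HV, V, K] with V = K = head_dim,
    # read once and written once per decode step
    return 2 * batch * num_v_heads * head_dim * head_dim * elem_size

# B=128, HV=32, HD=128 at the measured 59.39 us:
tbps = kda_state_bytes(128, 32, 128) / 59.39e-6 / 1e12
# ~4.5 TB/s from state traffic alone, close to the reported 4.59 TB/s
# (the remainder is q/k/v/g/beta/output traffic)
```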

Uses GitHub's mergeUpstream API to keep main branch in sync
with flashinfer-ai/flashinfer. Runs every hour and can be triggered manually.

Add KDA (Key-Driven Attention) decode support as a CuTe DSL kernel,
extending the GDN decode kernel from PR flashinfer-ai#2498 to support per-key-dimension
gating. KDA generalizes GDN's scalar gate (g in R^1) to per-K gating
(g in R^K), with the gate mapping naturally to the warp structure.

Changes:
- Extract shared gate-independent helpers from GDN kernel into
  flashinfer/gdn_kernels/_common.py (~290 lines), slimming
  gdn_decode_bf16_state.py. No GDN behavior change.
- Add HEAD_DIM=64 support to GDN dispatch (previously 128 only)
- Preserve lowBS_1chunk kernel variants for B<=4 (both GDN and KDA)
- New flashinfer/kda_kernels/ module with T=1-4 kernels for
  HEAD_DIM={64,128}, plus chunk_kda-compatible wrapper
- 80 KDA tests covering correctness, state updates, GDN reduction
- KDA decode benchmark

Tested on B200 (SM100) with CUDA 12.9. BF16 storage, FP32 compute.
GDN: 138/138 tests pass, no performance regression.
KDA: 80/80 tests pass.

AI-assisted (Claude Code)

- KDA DLPack cache key: (T, B, HEAD_DIM) -> (T, B, H, HV, HEAD_DIM)
  to avoid incorrect kernel reuse when H or HV differs
- Update stale cache dict comments in both GDN and KDA modules
- bench_kda_decode: fix dtype.itemsize crash (use tensor.element_size())
- bench_kda_decode: prefix unused num_k_heads with underscore
- kda_decode_bf16_state: use importlib.util.find_spec for tvm_ffi check
- kda_decode_bf16_state: rename unused o_head param to _o_head
- kda_decode_bf16_state: assert use_qk_l2norm_in_kernel=True (always on)
- gdn_decode: update docstring K=V=128 -> K=V in {64,128}
- test_decode_kda: fix SM arch check to lower-bound (cc[0] < 9)
- test_decode_kda: rename test_kda_reduces_to_gdn (misleading name)

g, beta, output, state use HV (num_v_heads), not max(q, v).
q/k use H (num_q_heads). Fixes skewed TB/s metric for GQA configs.

Reduce ATOL/RTOL from 1e-1/5e-2 to 5e-3/5e-3, matching the GDN decode
test tolerances. Measured worst-case errors are ~1.5e-5 absolute and
~7.8e-3 relative, so 5e-3 still provides ample margin.

…GDN changes

Replace the KDA decode kernel with the new recurrent_kda implementation
that uses TVM FFI compilation and supports cu_seqlens, ssm_state_indices,
and in-kernel gate computation (softplus/lower-bound modes). The kernel is
now T=1 only. Revert all GDN refactoring that was part of the original PR
since it is no longer needed (KDA is fully self-contained).

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

♻️ Duplicate comments (3)
flashinfer/kda_kernels/kda_decode_bf16_state.py (1)

1079-1080: ⚠️ Potential issue | 🟠 Major

use_qk_l2norm_in_kernel is still silently ignored.

The flag is exposed in API/docs but never changes behavior. Please explicitly reject False until a non-normalizing path exists.

🔧 Minimal safe guard
 def recurrent_kda(
@@
 ) -> torch.Tensor:
@@
+    if not use_qk_l2norm_in_kernel:
+        raise NotImplementedError(
+            "recurrent_kda currently requires use_qk_l2norm_in_kernel=True."
+        )

Also applies to: 1132-1135

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/kda_kernels/kda_decode_bf16_state.py` around lines 1079 - 1080,
The parameter use_qk_l2norm_in_kernel is declared but ignored; add an explicit
runtime guard that rejects False (e.g., raise ValueError) wherever the flag is
accepted so callers cannot silently disable L2-normalization until a
non-normalizing kernel path is implemented; update the check in the
constructor/function that defines use_qk_l2norm_in_kernel (and mirror the same
guard at the other occurrence around lines 1132-1135) so any attempt to pass
use_qk_l2norm_in_kernel=False fails fast with a clear error message referencing
use_qk_l2norm_in_kernel and the containing function/class name.
benchmarks/bench_kda_decode.py (1)

74-99: ⚠️ Potential issue | 🟠 Major

Use query-head count for k_bytes in this benchmark path.

Line 146 builds k with num_q_heads, but Line 98 models k_bytes with num_k_heads. That skews reported TB/s whenever the two differ.

🔧 Proposed fix
 def kda_decode_bytes(
     batch_size: int,
     num_q_heads: int,
-    num_k_heads: int,
+    _num_k_heads: int,
@@
-    k_bytes = batch_size * seq_len * num_k_heads * head_size * elem_size
+    k_bytes = batch_size * seq_len * num_q_heads * head_size * elem_size

Based on learnings: In flashinfer/kda_kernels/kda_decode_bf16_state.py, k is indexed with query_head_idx and expected as [B, T, num_q_heads, K] (pre-expanded to query heads).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_kda_decode.py` around lines 74 - 99, In kda_decode_bytes the
modeled k_bytes uses num_k_heads but the actual kernel and construction of k use
query heads (num_q_heads); change the k_bytes calculation to use num_q_heads
instead of num_k_heads so k_bytes = batch_size * seq_len * num_q_heads *
head_size * elem_size (update the expression referenced as k_bytes inside
kda_decode_bytes).
tests/kda/test_decode_kda.py (1)

4-9: ⚠️ Potential issue | 🟠 Major

Add required architecture skip helper from flashinfer.utils.

These tests launch the CuTe KDA kernel but do not gate execution on supported GPU architectures. Please add a flashinfer.utils-based skip guard and call it at the start of each test.

🔧 Proposed fix
 import pytest
 import torch
 import torch.nn.functional as F
 
 from flashinfer.kda_kernels import cutedsl_kda_decode
+from flashinfer.utils import get_compute_capability
+
+
+def _skip_if_not_sm100_or_later():
+    cc = get_compute_capability(torch.device("cuda"))
+    if cc[0] < 10:
+        pytest.skip(f"KDA decode requires SM100+, but got SM{cc[0]}{cc[1]}")
@@
 def test_cutedsl_vs_naive(
@@
 ):
+    _skip_if_not_sm100_or_later()
     """CuTe DSL kernel matches naive recurrent KDA reference."""
@@
 def test_cutedsl_vs_fla(
@@
 ):
+    _skip_if_not_sm100_or_later()
     """CuTe DSL kernel matches fla fused_recurrent_kda."""
@@
 def test_vllm_decode(
@@
 ):
+    _skip_if_not_sm100_or_later()
     """vLLM-style decoding: continuous batching with paged state, CuTe DSL vs naive."""

As per coding guidelines, tests/**/*.py: “Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures”.

Also applies to: 119-123, 179-183, 240-246

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/kda/test_decode_kda.py` around lines 4 - 9, Import the GPU-architecture
helper(s) from flashinfer.utils (for example get_compute_capability and
is_sm90a_supported) at the top of the test file and add a guard at the start of
each test that runs the CuTe KDA kernel (those invoking cutedsl_kda_decode)
which calls pytest.skip when the current GPU architecture is not supported;
specifically, for each test function call the helper(s) to detect support and
call pytest.skip with a short message if unsupported so the CuTe KDA tests are
gated on supported architectures.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_kda_decode.py`:
- Around line 220-224: The CLI and runner are still allowing seq lengths
[1,2,3,4] while the kernel only supports T=1, causing noisy ERROR rows; update
the filtering and defaults to enforce seq_len == 1 by restricting valid_seq_lens
to only include 1 (replace the comprehension using args.seq_len with one that
only accepts 1) and update any other occurrences (e.g., the runner/CLI handling
around the second occurrence at the 311-315 region) to default to and only
iterate over seq_len == 1 so only T=1 runs are scheduled.

In `@flashinfer/kda_kernels/kda_decode_bf16_state.py`:
- Line 1139: Remove the unused local H at the q.shape unpack; change the
assignment that currently reads "H, K = q.shape[2], q.shape[3]" so it only
assigns K (e.g., assign K from q.shape[3] or unpack into a throwaway for the
first value) to eliminate the unused-variable F841; locate this in
kda_decode_bf16_state where q.shape is referenced and update the assignment
accordingly.
- Around line 1170-1172: The default output allocation incorrectly uses a
hardcoded second dimension of 1 causing out-of-bounds indexing in cu_seqlens
mode; modify the allocation for output so its shape uses q.shape[1] (the
flattened token axis) when running in cu_seqlens mode (i.e., when cu_seqlens is
present/used), e.g. set the second dimension to (q.shape[1] if cu_seqlens is not
None else 1) so indexing by token_offset/gO is valid; update the allocation site
where output is created (symbols: output, q, cu_seqlens, token_offset, gO)
accordingly.
- Around line 1317-1334: The cu_seqlens branch currently unconditionally creates
out_buf = torch.zeros_like(v) and ignores the caller-provided output buffer;
change that branch to mirror the non-cu_seqlens logic: use the provided output
if not None (validating shape/device/dtype match to v) else allocate a new
buffer (e.g., torch.zeros_like(v)); ensure you still set out_buf to this buffer
and preserve existing behavior for initial_state/state/ssi, and do not change
variable names (cu_seqlens, out_buf, v, output, initial_state, state).
- Around line 1302-1306: Before dispatch, add guards validating head-ratio and
tensor shapes: ensure HV >= H and HV % H == 0 before computing query_head_idx =
value_head_idx // (HV // H), and validate that k, g, beta tensors have expected
head/feature dimensions matching K, V, HV (e.g., check k.ndim and last dims
equal K, g shape aligns with H or HV, beta matches attention heads). Also keep
existing dtype/shape asserts (q, v, K==V, K in (64,128)) but add explicit checks
for value_head_idx bounds and that (HV // H) != 0 to avoid division by zero and
out-of-range indexing in functions/methods computing query_head_idx and using
k/g/beta.
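The head mapping that this guard protects can be illustrated with a quick sketch (assumed indexing, following the review's `query_head_idx = value_head_idx // (HV // H)` with the benchmark's H=16, HV=32 config):

```python
# Each query head serves HV // H consecutive value heads; the guard
# (HV >= H and HV % H == 0) keeps the integer division well-defined.
H, HV = 16, 32
assert HV >= H and HV % H == 0
group = HV // H                               # value heads per query head
mapping = [vh // group for vh in range(HV)]   # value head -> query head
# consecutive value heads share a query head: [0, 0, 1, 1, ..., 15, 15]
```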


ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 77ed0eb and 6bb4b57.

📒 Files selected for processing (7)
  • benchmarks/bench_kda_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/kda_kernels/__init__.py
  • flashinfer/kda_kernels/cutedsl_kda_decode.py
  • flashinfer/kda_kernels/kda_decode_bf16_state.py
  • tests/gdn/test_decode_delta_rule.py
  • tests/kda/test_decode_kda.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/gdn/test_decode_delta_rule.py
  • flashinfer/kda_kernels/cutedsl_kda_decode.py

Rename kda_decode_bf16_state.py -> recurrent_kda.py and merge the
separate _compile_tvm_ffi / recurrent_kda / cutedsl_kda_decode layers
into a single cutedsl_kda_decode function with internal _dispatch_kernel
and _compile_tvm_ffi helpers. Remove unused RecurrentKDAKernel class.
- Tighten test tolerances from 1e-1/5e-2 to 5e-3/5e-3 (yzh119)
- Add HV >= H and HV % H == 0 guard before dispatch (CodeRabbit)
- Respect caller-provided output buffer in cu_seqlens path (CodeRabbit)
- Restrict benchmark CLI to T=1 only (CodeRabbit)
@djmmoss djmmoss changed the title feat(kda): add KDA decode CuTe DSL kernel with per-K gating feat(kda): add recurrent KDA decode kernel with per-K gating Feb 26, 2026
The recurrent KDA kernel requires SM100a (Blackwell). Add module-level
pytest.skip using is_sm100a_supported() so CI on Hopper/Ampere skips
gracefully instead of failing at CuTe DSL compilation.

Module-level pytest.skip causes 0 tests collected, which CI treats as
failure. Switch to per-test _skip_if_not_sm100() calls matching the
pattern used by GDN tests, so tests are collected and individually
skipped on non-SM100 architectures.