81 commits
16f45f6 Add workflow to sync fork with upstream hourly (djmmoss, Feb 3, 2026)
42a6303 Add workflows permission for syncing upstream workflow files (djmmoss, Feb 3, 2026)
f7510a9 Use PAT_TOKEN for workflow write permission (djmmoss, Feb 3, 2026)
854cd91 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 3, 2026)
e6f9d9b Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 4, 2026)
fbb4cb9 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 4, 2026)
2f0f7a8 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 4, 2026)
0b28582 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 4, 2026)
cd37078 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 4, 2026)
2c13497 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 4, 2026)
a0c2f28 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 5, 2026)
514504d Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 5, 2026)
0f633a4 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 5, 2026)
351a34e Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 6, 2026)
35617a2 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 6, 2026)
ac6c98b Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 7, 2026)
e225c2d Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 8, 2026)
91ea3cf Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 9, 2026)
dfa4f6f Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 11, 2026)
ce6f822 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 11, 2026)
c084272 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 11, 2026)
0abebff Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 12, 2026)
abc18a6 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 12, 2026)
255fe17 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 12, 2026)
f1438f4 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 12, 2026)
31a2ac6 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 13, 2026)
3779148 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 13, 2026)
294d727 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 13, 2026)
89bb2f7 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 14, 2026)
4da18bf Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 16, 2026)
bd43b41 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 16, 2026)
4d596ac Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 16, 2026)
75f7b3d Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 16, 2026)
43d6280 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 17, 2026)
e2fca86 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 17, 2026)
f562755 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 17, 2026)
270cbcc Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 18, 2026)
4cde58d Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 18, 2026)
a3c0510 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 19, 2026)
a984769 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 19, 2026)
8475bb6 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 19, 2026)
24e38b6 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 20, 2026)
07dbc92 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 20, 2026)
d2288c8 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 20, 2026)
c74c181 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 20, 2026)
298448b Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 20, 2026)
0226374 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 21, 2026)
9b18e4a Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 21, 2026)
ecca1a8 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 22, 2026)
4c47cfa Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 23, 2026)
45e1806 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 23, 2026)
88ed6fb Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 23, 2026)
fffe73c Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 24, 2026)
ced9c58 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 25, 2026)
e5061ac Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 25, 2026)
1c5fbb7 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 25, 2026)
3771194 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 25, 2026)
8b3d017 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 25, 2026)
3e1c22a Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 26, 2026)
b95d321 Merge branch 'flashinfer-ai:main' into main (djmmoss, Feb 26, 2026)
fb1eea9 feat(kda): add KDA decode CuTe DSL kernel with per-K gating (djmmoss, Feb 17, 2026)
8aff823 fix: correct DLPack cache key and stale cache comments (djmmoss, Feb 17, 2026)
044417e fix: address CodeRabbit review feedback (djmmoss, Feb 17, 2026)
5608f7d fix: use correct head counts in KDA benchmark byte/flop accounting (djmmoss, Feb 17, 2026)
b728afc fix: tighten KDA test tolerances to match GDN decode (djmmoss, Feb 19, 2026)
78c848b refactor(kda): replace KDA kernel with recurrent T=1 variant, revert … (djmmoss, Feb 26, 2026)
7c74427 chore: remove fork-specific sync-upstream workflow from PR (djmmoss, Feb 26, 2026)
b42c6b1 refactor(kda): rename to recurrent_kda.py, merge into single public API (djmmoss, Feb 26, 2026)
2261359 refactor(kda): rename cutedsl_kda_decode -> recurrent_kda (djmmoss, Feb 26, 2026)
09a0eb2 docs(kda): update recurrent_kda docstring to match GDN style (djmmoss, Feb 26, 2026)
125fdbb refactor(kda): use @functools.cache for kernel compilation, inline di… (djmmoss, Feb 26, 2026)
0baa523 fix: address PR review feedback (djmmoss, Feb 26, 2026)
7761a34 refactor(kda): rename test functions and comments to match recurrent_… (djmmoss, Feb 26, 2026)
61af50a chore(kda): remove stale cutedsl_kda_decode.py re-export shim (djmmoss, Feb 26, 2026)
c367b2d rename: bench_kda_decode -> bench_recurrent_kda, test_decode_kda -> t… (djmmoss, Feb 26, 2026)
c8069f2 refactor: rename internal benchmark functions to match recurrent_kda … (djmmoss, Feb 26, 2026)
038fb80 refactor: rename KDA_DECODE_AVAILABLE -> RECURRENT_KDA_AVAILABLE (djmmoss, Feb 26, 2026)
9a04856 refactor: replace remaining 'KDA decode' references in benchmark (djmmoss, Feb 26, 2026)
f7cd392 style: apply ruff format (djmmoss, Feb 26, 2026)
d77d9d8 fix(kda): skip recurrent KDA tests on non-SM100 architectures (djmmoss, Feb 26, 2026)
8e2009e fix(kda): use per-test skip instead of module-level skip for SM100 guard (djmmoss, Feb 27, 2026)
337 changes: 337 additions & 0 deletions benchmarks/bench_recurrent_kda.py
@@ -0,0 +1,337 @@
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Recurrent KDA (Key-Driven Attention) Benchmark

Benchmarks the recurrent KDA decode kernel with per-K-dimension gating.
KDA differs from GDN in its gate: g[B, T, HV, K] is applied per K lane,
instead of GDN's scalar per-head gate.

Usage:
    python benchmarks/bench_recurrent_kda.py --batch-size 1 4 16 64 128 256
    python benchmarks/bench_recurrent_kda.py --head-size 64 --batch-size 1 32 128
    python benchmarks/bench_recurrent_kda.py --seq-len 1 --batch-size 1 32
"""

import argparse
import numpy as np
import torch

from flashinfer.testing import bench_gpu_time

# Import the recurrent KDA kernel
try:
from flashinfer.kda_kernels import recurrent_kda

RECURRENT_KDA_AVAILABLE = True
except ImportError:
RECURRENT_KDA_AVAILABLE = False


# ============================================================================
# FLOPs and Bytes Calculation
# ============================================================================


def recurrent_kda_flops(
batch_size: int,
num_q_heads: int,
_num_k_heads: int,
num_v_heads: int,
head_size: int,
seq_len: int = 1,
) -> int:
"""
Calculate FLOPs for KDA (Key-Driven Attention) decode.

8 * K * V FLOPs per token per head:
1. k @ state (prediction): 2 * K * V
2. k^T @ v_new (update): 2 * K * V
3. q @ state (output): 2 * K * V
4. Per-K gate application: 2 * K * V (K*V element-wise multiply + K exp() calls)

Note: K = V = head_size for KDA. State ops are per-HV (value) head.
"""
total_flops = 8 * seq_len * batch_size * num_v_heads * head_size * head_size
return total_flops


def recurrent_kda_bytes(
batch_size: int,
num_q_heads: int,
num_k_heads: int,
num_v_heads: int,
head_size: int,
dtype: torch.dtype,
seq_len: int = 1,
) -> int:
    """
    Calculate memory bytes moved per call of recurrent KDA.

    Includes:
    - Q: [B, T, H, K], K: [B, T, HK, K], V: [B, T, HV, V] - dtype
    - G tensor (per-K gate): [B, T, HV, K] - dtype (extra traffic vs GDN)
    - Beta: [B, T, HV] - dtype
    - State (read + write): [B, HV, V, K] - bf16 (2 bytes per element)
    - Output: [B, T, HV, V] - dtype
    """
elem_size = torch.tensor([], dtype=dtype).element_size()
state_dtype_bytes = 2 # BF16 state

    # Input tensors: q uses H (query heads), k uses HK (key heads), v uses HV (value heads)
q_bytes = batch_size * seq_len * num_q_heads * head_size * elem_size
k_bytes = batch_size * seq_len * num_k_heads * head_size * elem_size
v_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size

# Per-K gate: [B, T, HV, K]
g_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size

# Beta: [B, T, HV]
beta_bytes = batch_size * seq_len * num_v_heads * elem_size

# Output: [B, T, HV, V]
o_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size

# State: [B, HV, V, K] read + write
state_bytes = (
2 * batch_size * num_v_heads * head_size * head_size * state_dtype_bytes
)

total_bytes = (
q_bytes + k_bytes + v_bytes + g_bytes + beta_bytes + o_bytes + state_bytes
)
return total_bytes


# ============================================================================
# Benchmark Function
# ============================================================================


def bench_recurrent_kda(
batch_size: int,
seq_len: int,
num_q_heads: int,
num_k_heads: int,
num_v_heads: int,
head_size: int,
dtype: torch.dtype,
warmup_iters: int = 10,
bench_iters: int = 100,
):
"""Benchmark recurrent KDA kernel for T=1."""
if not RECURRENT_KDA_AVAILABLE:
raise RuntimeError("recurrent KDA kernel is not available")

assert seq_len == 1, f"recurrent KDA supports T=1 only, got T={seq_len}"

# Create inputs
T = seq_len
q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda")
    k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda")
v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda")

# KDA-specific: per-K log-space gate [B, T, HV, K]
g = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda")

# Beta: [B, T, HV] (pre-sigmoided)
beta = torch.randn(batch_size, T, num_v_heads, dtype=dtype, device="cuda")

# Initial state: [B, HV, V, K] (K-last layout, BF16)
state = torch.randn(
batch_size,
num_v_heads,
head_size,
head_size,
dtype=torch.bfloat16,
device="cuda",
)

# Scale factor
scale = 1.0 / (head_size**0.5)

# Benchmark with bench_gpu_time (CUPTI for accurate kernel timing)
kernel_times_ms = bench_gpu_time(
lambda: recurrent_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=state,
scale=scale,
use_qk_l2norm_in_kernel=True,
),
enable_cupti=True,
dry_run_iters=warmup_iters,
repeat_iters=bench_iters,
)

# Calculate metrics
kernel_median_ms = np.median(kernel_times_ms)
flops = recurrent_kda_flops(
batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, seq_len
)
bytes_accessed = recurrent_kda_bytes(
batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, dtype, seq_len
)

kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0
kernel_tb_per_sec = (
bytes_accessed / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0
)

return {
"batch_size": batch_size,
"seq_len": seq_len,
"kernel_median_us": kernel_median_ms * 1000,
"kernel_tflops": kernel_tflops,
"kernel_tb_per_sec": kernel_tb_per_sec,
}


# ============================================================================
# Runner
# ============================================================================


def run_recurrent_kda_benchmark(args, dtype):
"""Run recurrent KDA benchmark for T=1."""
if not RECURRENT_KDA_AVAILABLE:
print("Error: recurrent KDA kernel is not available.")
print("Make sure flashinfer.kda_kernels.recurrent_kda is importable.")
return

# Filter seq_len to only valid values (T=1 only)
valid_seq_lens = [t for t in args.seq_len if t == 1]
if not valid_seq_lens:
print("Error: --seq-len must include 1 (kernel supports T=1 only)")
return

print("\n" + "=" * 100)
print(f"Recurrent KDA Benchmark (T={valid_seq_lens})")
print(
f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, "
f"v_heads={args.num_v_heads}, head_size={args.head_size}, "
f"dtype={args.dtype}"
)
print("=" * 100)
print()
print(f"{'batch':>6} {'T':>4} {'time(us)':>10} {'TFLOPS':>10} {'TB/s':>10}")
print("-" * 100)

all_results = []
for batch_size in args.batch_size:
for seq_len in valid_seq_lens:
try:
result = bench_recurrent_kda(
batch_size=batch_size,
seq_len=seq_len,
num_q_heads=args.num_q_heads,
num_k_heads=args.num_k_heads,
num_v_heads=args.num_v_heads,
head_size=args.head_size,
dtype=dtype,
warmup_iters=args.warmup,
bench_iters=args.iters,
)
all_results.append(result)

print(
f"{result['batch_size']:>6} {result['seq_len']:>4} "
f"{result['kernel_median_us']:>10.2f} "
f"{result['kernel_tflops']:>10.2f} "
f"{result['kernel_tb_per_sec']:>10.2f}"
)
except Exception as e:
print(
f"{batch_size:>6} {seq_len:>4} {'ERROR':>10} - {type(e).__name__}: {e}"
)

print("-" * 100)
print()

# Summary by T value
for t in valid_seq_lens:
t_results = [r for r in all_results if r["seq_len"] == t]
if t_results:
avg_time = np.mean([r["kernel_median_us"] for r in t_results])
avg_tflops = np.mean([r["kernel_tflops"] for r in t_results])
print(
f"T={t}: Average time={avg_time:.2f}us, Average TFLOPS={avg_tflops:.2f}"
)


# ============================================================================
# Main
# ============================================================================


def main():
parser = argparse.ArgumentParser(
description="Recurrent KDA Benchmark",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python benchmarks/bench_recurrent_kda.py --batch-size 1 4 16 64 128 256
python benchmarks/bench_recurrent_kda.py --head-size 64 --batch-size 1 32 128
python benchmarks/bench_recurrent_kda.py --seq-len 1 --batch-size 1 32
""",
)
parser.add_argument(
"--batch-size",
type=int,
nargs="+",
default=[1, 4, 16, 64, 128, 256],
help="Batch sizes to benchmark",
)
parser.add_argument("--num-q-heads", type=int, default=16)
parser.add_argument("--num-k-heads", type=int, default=16)
parser.add_argument("--num-v-heads", type=int, default=32)
parser.add_argument("--head-size", type=int, default=128, choices=[64, 128])
parser.add_argument(
"--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16"
)
parser.add_argument(
"--seq-len",
type=int,
nargs="+",
default=[1],
help="Sequence length (T=1 only)",
)
parser.add_argument(
"--warmup",
type=int,
default=10,
help="Number of warmup iterations",
)
parser.add_argument(
"--iters",
type=int,
default=100,
help="Number of benchmark iterations",
)
args = parser.parse_args()

# Resolve dtype
dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16}
dtype = dtype_map[args.dtype]

run_recurrent_kda_benchmark(args, dtype)


if __name__ == "__main__":
main()
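The byte/FLOP accounting above implies T=1 decode is strongly memory-bound: at the default config the kernel performs roughly two FLOPs per byte moved, dominated by the state read+write. A quick arithmetic-intensity check mirroring the formulas (plain Python, no GPU; the helper names here are illustrative, not part of the benchmark):

```python
def kda_flops(B, HV, K, T=1):
    # 8 * K * V FLOPs per token per value head (K == V == head_size)
    return 8 * T * B * HV * K * K

def kda_bytes(B, H, HK, HV, K, elem=2, state_elem=2, T=1):
    qkv = B * T * (H + HK + HV) * K * elem     # q, k, v tensors
    gate = B * T * HV * K * elem               # per-K gate (extra traffic vs GDN)
    beta = B * T * HV * elem                   # beta scalar per value head
    out = B * T * HV * K * elem                # output tensor
    state = 2 * B * HV * K * K * state_elem    # state read + write, bf16
    return qkv + gate + beta + out + state

# Default benchmark config: 16 q heads, 16 k heads, 32 v heads, head_size 128
B, H, HK, HV, K = 128, 16, 16, 32, 128
ai = kda_flops(B, HV, K) / kda_bytes(B, H, HK, HV, K)
print(f"arithmetic intensity: {ai:.2f} FLOP/byte")  # well below GPU roofline ridge
```

At ~2 FLOP/byte the kernel sits far below the compute/memory ridge point of any modern GPU, so TB/s, not TFLOPS, is the number to watch in the benchmark output.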
38 changes: 38 additions & 0 deletions flashinfer/kda_kernels/__init__.py
@@ -0,0 +1,38 @@
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
KDA (Key-Driven Attention) Kernels - CuTe DSL Implementations
==============================================================

Per-K-dimension gating variant of GDN: the gate g[B, T, HV, K] is applied
per K lane instead of GDN's scalar per-head broadcast.

Exported:
- recurrent_kda: Recurrent KDA decode kernel (T=1)
"""

try:
from .recurrent_kda import recurrent_kda

_has_cute_dsl = True
except ImportError:
_has_cute_dsl = False
recurrent_kda = None # type: ignore

__all__ = [
"recurrent_kda",
]
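For intuition, the per-K-gated delta-rule step that `recurrent_kda` performs for T=1 can be sketched in NumPy. This is a reference sketch under assumed semantics (GQA-style broadcast of q/k heads to value heads, decay applied before the delta update); the real CuTe DSL kernel also fuses an optional q/k L2 norm, omitted here:

```python
import numpy as np

def recurrent_kda_step(q, k, v, g, beta, state, scale):
    """One T=1 step. q, k: [B, H, K]; v: [B, HV, V]; g: [B, HV, K];
    beta: [B, HV]; state: [B, HV, V, K] (K-last, matching the benchmark)."""
    B, H, K = q.shape
    HV = v.shape[1]
    rep = HV // H                                    # assumed q/k -> HV broadcast
    q = np.repeat(q, rep, axis=1) * scale            # [B, HV, K]
    k = np.repeat(k, rep, axis=1)                    # [B, HV, K]
    S = state * np.exp(g)[:, :, None, :]             # per-K decay on state's K axis
    pred = np.einsum("bhvk,bhk->bhv", S, k)          # k @ state: predicted value
    delta = (v - pred) * beta[:, :, None]            # beta-scaled correction
    S = S + np.einsum("bhv,bhk->bhvk", delta, k)     # rank-1 delta update
    o = np.einsum("bhvk,bhk->bhv", S, q)             # q @ state: output
    return o, S
```

With g ≡ 0 and beta ≡ 0 the state passes through unchanged; GDN corresponds to collapsing the per-K vector g to a single scalar per head.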