feat(kda): add recurrent KDA decode kernel with per-K gating #2572

Status: Open. djmmoss wants to merge 81 commits into flashinfer-ai:main from djmmoss:kda-decode-cutedsl.
Commits (81, all by djmmoss):

- 16f45f6 Add workflow to sync fork with upstream hourly
- 42a6303 Add workflows permission for syncing upstream workflow files
- f7510a9 Use PAT_TOKEN for workflow write permission
- 854cd91, e6f9d9b, fbb4cb9, 2f0f7a8, 0b28582, cd37078, 2c13497, a0c2f28, 514504d, 0f633a4, 351a34e, 35617a2, ac6c98b, e225c2d, 91ea3cf, dfa4f6f, ce6f822, c084272, 0abebff, abc18a6, 255fe17, f1438f4, 31a2ac6, 3779148, 294d727, 89bb2f7, 4da18bf, bd43b41, 4d596ac, 75f7b3d, 43d6280, e2fca86, f562755, 270cbcc, 4cde58d, a3c0510, a984769, 8475bb6, 24e38b6, 07dbc92, d2288c8, c74c181, 298448b, 0226374, 9b18e4a, ecca1a8, 4c47cfa, 45e1806, 88ed6fb, fffe73c, ced9c58, e5061ac, 1c5fbb7, 3771194, 8b3d017, 3e1c22a, b95d321 — Merge branch 'flashinfer-ai:main' into main (57 repeated upstream-sync merges)
- fb1eea9 feat(kda): add KDA decode CuTe DSL kernel with per-K gating
- 8aff823 fix: correct DLPack cache key and stale cache comments
- 044417e fix: address CodeRabbit review feedback
- 5608f7d fix: use correct head counts in KDA benchmark byte/flop accounting
- b728afc fix: tighten KDA test tolerances to match GDN decode
- 78c848b refactor(kda): replace KDA kernel with recurrent T=1 variant, revert …
- 7c74427 chore: remove fork-specific sync-upstream workflow from PR
- b42c6b1 refactor(kda): rename to recurrent_kda.py, merge into single public API
- 2261359 refactor(kda): rename cutedsl_kda_decode -> recurrent_kda
- 09a0eb2 docs(kda): update recurrent_kda docstring to match GDN style
- 125fdbb refactor(kda): use @functools.cache for kernel compilation, inline di…
- 0baa523 fix: address PR review feedback
- 7761a34 refactor(kda): rename test functions and comments to match recurrent_…
- 61af50a chore(kda): remove stale cutedsl_kda_decode.py re-export shim
- c367b2d rename: bench_kda_decode -> bench_recurrent_kda, test_decode_kda -> t…
- c8069f2 refactor: rename internal benchmark functions to match recurrent_kda …
- 038fb80 refactor: rename KDA_DECODE_AVAILABLE -> RECURRENT_KDA_AVAILABLE
- 9a04856 refactor: replace remaining 'KDA decode' references in benchmark
- f7cd392 style: apply ruff format
- d77d9d8 fix(kda): skip recurrent KDA tests on non-SM100 architectures
- 8e2009e fix(kda): use per-test skip instead of module-level skip for SM100 guard
`benchmarks/bench_recurrent_kda.py` (new file, +337 lines):
| """ | ||
| Copyright (c) 2025 by FlashInfer team. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| """ | ||
| Recurrent KDA (Key-Driven Attention) Benchmark | ||
|
|
||
| Benchmarks the recurrent KDA kernel with per-K-dimension gating. | ||
| KDA differs from GDN by having gate g[B, T, HV, K] instead of a scalar gate. | ||
|
|
||
| Usage: | ||
| python benchmarks/bench_recurrent_kda.py --batch-size 1 4 16 64 128 256 | ||
| python benchmarks/bench_recurrent_kda.py --head-size 64 --batch-size 1 32 128 | ||
| python benchmarks/bench_recurrent_kda.py --seq-len 1 2 3 4 --batch-size 1 32 | ||
| """ | ||
|
|
||
| import argparse | ||
| import numpy as np | ||
| import torch | ||
|
|
||
| from flashinfer.testing import bench_gpu_time | ||
|
|
||
| # Import the recurrent KDA kernel | ||
| try: | ||
| from flashinfer.kda_kernels import recurrent_kda | ||
|
|
||
| RECURRENT_KDA_AVAILABLE = True | ||
| except ImportError: | ||
| RECURRENT_KDA_AVAILABLE = False | ||
|
|
||
|
|
||
```python
# ============================================================================
# FLOPs and Bytes Calculation
# ============================================================================


def recurrent_kda_flops(
    batch_size: int,
    num_q_heads: int,
    _num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    seq_len: int = 1,
) -> int:
    """
    Calculate FLOPs for KDA (Key-Driven Attention) decode.

    8 * K * V FLOPs per token per head:
        1. k @ state (prediction): 2 * K * V
        2. k^T @ v_new (update): 2 * K * V
        3. q @ state (output): 2 * K * V
        4. Per-K gate application: 2 * K * V (K*V element-wise multiply + K exp() calls)

    Note: K = V = head_size for KDA. State ops are per-HV (value) head.
    """
    total_flops = 8 * seq_len * batch_size * num_v_heads * head_size * head_size
    return total_flops
```
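To make the accounting above concrete, here is the same formula evaluated standalone for the benchmark's default shape (B=128, HV=32, K=V=128, T=1). This is just the arithmetic from the docstring restated, not part of the benchmark file:

```python
def kda_decode_flops(batch_size: int, num_v_heads: int, head_size: int,
                     seq_len: int = 1) -> int:
    # 8 * K * V FLOPs per token per value head: 2*K*V each for the
    # k@state prediction, the rank-1 state update, the q@state output,
    # and the per-K gate application (K = V = head_size).
    return 8 * seq_len * batch_size * num_v_heads * head_size * head_size

# Default benchmark shape: batch 128, 32 value heads, head_size 128
print(kda_decode_flops(128, 32, 128))  # 536870912 -> ~0.54 GFLOP per decode step
```

At these sizes a single decode step is only about half a GFLOP, which is why the kernel is memory-bound rather than compute-bound.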
```python
def recurrent_kda_bytes(
    batch_size: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    dtype: torch.dtype,
    seq_len: int = 1,
) -> int:
    """
    Calculate memory bytes for recurrent KDA.

    Includes:
        - Q, K, V tensors: [B, T, H, K] - dtype
        - G tensor (per-K gate): [B, T, HV, K] - dtype (extra vs GDN)
        - Beta: [B, T, HV] - dtype
        - State (read + write): [B, HV, V, K] - bf16 (2 bytes)
        - Output: [B, T, HV, V] - dtype
    """
    elem_size = torch.tensor([], dtype=dtype).element_size()
    state_dtype_bytes = 2  # BF16 state

    # Input tensors: q uses H (query heads), k uses HK (key heads),
    # v uses HV (value heads)
    q_bytes = batch_size * seq_len * num_q_heads * head_size * elem_size
    k_bytes = batch_size * seq_len * num_k_heads * head_size * elem_size
    v_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size

    # Per-K gate: [B, T, HV, K]
    g_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size

    # Beta: [B, T, HV]
    beta_bytes = batch_size * seq_len * num_v_heads * elem_size

    # Output: [B, T, HV, V]
    o_bytes = batch_size * seq_len * num_v_heads * head_size * elem_size

    # State: [B, HV, V, K] read + write
    state_bytes = (
        2 * batch_size * num_v_heads * head_size * head_size * state_dtype_bytes
    )

    total_bytes = (
        q_bytes + k_bytes + v_bytes + g_bytes + beta_bytes + o_bytes + state_bytes
    )
    return total_bytes
```
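Plugging the default configuration (B=128, H=HK=16, HV=32, K=V=128, bf16 everywhere, so 2-byte elements) into the byte accounting shows the recurrent state dominates traffic. A standalone recomputation mirroring that accounting:

```python
def kda_decode_bytes(B, H, HK, HV, K, elem=2, state_elem=2, T=1):
    qkv = B * T * (H + HK + HV) * K * elem   # q, k, v inputs
    gate = B * T * HV * K * elem             # per-K gate g
    beta = B * T * HV * elem                 # beta scalars per value head
    out = B * T * HV * K * elem              # output (V = K = head_size)
    state = 2 * B * HV * K * K * state_elem  # state read + write
    return qkv + gate + beta + out + state

total = kda_decode_bytes(128, 16, 16, 32, 128)
state = 2 * 128 * 32 * 128 * 128 * 2
print(total)             # 272637952 bytes (~272.6 MB per decode step)
print(state / total)     # ~0.985 -> state traffic is ~98% of the total
```

Since roughly 98% of the bytes moved are state read/write, the achievable TB/s on the state stream is the figure of merit this benchmark is really measuring.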
```python
# ============================================================================
# Benchmark Function
# ============================================================================


def bench_recurrent_kda(
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    dtype: torch.dtype,
    warmup_iters: int = 10,
    bench_iters: int = 100,
):
    """Benchmark recurrent KDA kernel for T=1."""
    if not RECURRENT_KDA_AVAILABLE:
        raise RuntimeError("recurrent KDA kernel is not available")

    assert seq_len == 1, f"recurrent KDA supports T=1 only, got T={seq_len}"

    # Create inputs (k uses num_k_heads, matching the byte accounting above)
    T = seq_len
    q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda")
    k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda")
    v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda")

    # KDA-specific: per-K log-space gate [B, T, HV, K]
    g = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda")

    # Beta: [B, T, HV] (pre-sigmoided)
    beta = torch.randn(batch_size, T, num_v_heads, dtype=dtype, device="cuda")

    # Initial state: [B, HV, V, K] (K-last layout, BF16)
    state = torch.randn(
        batch_size,
        num_v_heads,
        head_size,
        head_size,
        dtype=torch.bfloat16,
        device="cuda",
    )

    # Scale factor
    scale = 1.0 / (head_size**0.5)

    # Benchmark with bench_gpu_time (CUPTI for accurate kernel timing)
    kernel_times_ms = bench_gpu_time(
        lambda: recurrent_kda(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=state,
            scale=scale,
            use_qk_l2norm_in_kernel=True,
        ),
        enable_cupti=True,
        dry_run_iters=warmup_iters,
        repeat_iters=bench_iters,
    )

    # Calculate metrics
    kernel_median_ms = np.median(kernel_times_ms)
    flops = recurrent_kda_flops(
        batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, seq_len
    )
    bytes_accessed = recurrent_kda_bytes(
        batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, dtype, seq_len
    )

    kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0
    kernel_tb_per_sec = (
        bytes_accessed / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0
    )

    return {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "kernel_median_us": kernel_median_ms * 1000,
        "kernel_tflops": kernel_tflops,
        "kernel_tb_per_sec": kernel_tb_per_sec,
    }
```
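The kernel itself is not shown on this page, but the T=1 recurrence it benchmarks can be sketched in plain PyTorch from the step breakdown in the FLOPs docstring: per-K gate decay, delta-rule prediction and rank-1 update, then output projection. The exact operation ordering and L2-normalization placement inside the CuTe DSL kernel are assumptions here; this is an intuition aid, not a bit-exact reference:

```python
import torch
import torch.nn.functional as F

def recurrent_kda_sketch(q, k, v, g, beta, state, scale):
    """One token, one value head. q, k, g: [K]; v: [V]; state: [V, K]; K = V."""
    q = F.normalize(q, dim=-1)            # optional q/k L2 norm
    k = F.normalize(k, dim=-1)
    state = state * torch.exp(g)          # per-K gate, broadcast over V rows
    pred = state @ k                      # step 1: k @ state prediction, [V]
    delta = beta * (v - pred)             # delta-rule correction
    state = state + torch.outer(delta, k)  # step 2: rank-1 state update
    o = state @ (q * scale)               # step 3: q @ state output, [V]
    return o, state

K = V = 8
q, k, v = torch.randn(K), torch.randn(K), torch.randn(V)
g = torch.zeros(K)                        # zero log-gate -> no decay
state = torch.randn(V, K)
o, new_state = recurrent_kda_sketch(q, k, v, g, torch.tensor(1.0), state, 1.0)
# With g = 0 and beta = 1, the updated state reproduces v exactly along the
# (unit-norm) key direction: new_state @ k == pred + (v - pred) * (k . k) == v
print(torch.allclose(new_state @ k, v, atol=1e-4))  # True
```

The per-K gate is the only difference from a GDN-style scalar-gated delta rule: `exp(g)` here is a length-K vector broadcast across the V rows of the state, which is exactly the extra `g[B, T, HV, K]` tensor the byte accounting charges for.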
```python
# ============================================================================
# Runner
# ============================================================================


def run_recurrent_kda_benchmark(args, dtype):
    """Run recurrent KDA benchmark for T=1."""
    if not RECURRENT_KDA_AVAILABLE:
        print("Error: recurrent KDA kernel is not available.")
        print("Make sure flashinfer.kda_kernels.recurrent_kda is importable.")
        return

    # Filter seq_len to only valid values (T=1 only)
    valid_seq_lens = [t for t in args.seq_len if t == 1]
    if not valid_seq_lens:
        print("Error: --seq-len must include 1 (kernel supports T=1 only)")
        return

    print("\n" + "=" * 100)
    print(f"Recurrent KDA Benchmark (T={valid_seq_lens})")
    print(
        f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, "
        f"v_heads={args.num_v_heads}, head_size={args.head_size}, "
        f"dtype={args.dtype}"
    )
    print("=" * 100)
    print()
    print(f"{'batch':>6} {'T':>4} {'time(us)':>10} {'TFLOPS':>10} {'TB/s':>10}")
    print("-" * 100)

    all_results = []
    for batch_size in args.batch_size:
        for seq_len in valid_seq_lens:
            try:
                result = bench_recurrent_kda(
                    batch_size=batch_size,
                    seq_len=seq_len,
                    num_q_heads=args.num_q_heads,
                    num_k_heads=args.num_k_heads,
                    num_v_heads=args.num_v_heads,
                    head_size=args.head_size,
                    dtype=dtype,
                    warmup_iters=args.warmup,
                    bench_iters=args.iters,
                )
                all_results.append(result)

                print(
                    f"{result['batch_size']:>6} {result['seq_len']:>4} "
                    f"{result['kernel_median_us']:>10.2f} "
                    f"{result['kernel_tflops']:>10.2f} "
                    f"{result['kernel_tb_per_sec']:>10.2f}"
                )
            except Exception as e:
                print(
                    f"{batch_size:>6} {seq_len:>4} {'ERROR':>10} - {type(e).__name__}: {e}"
                )

    print("-" * 100)
    print()

    # Summary by T value
    for t in valid_seq_lens:
        t_results = [r for r in all_results if r["seq_len"] == t]
        if t_results:
            avg_time = np.mean([r["kernel_median_us"] for r in t_results])
            avg_tflops = np.mean([r["kernel_tflops"] for r in t_results])
            print(
                f"T={t}: Average time={avg_time:.2f}us, Average TFLOPS={avg_tflops:.2f}"
            )
```
```python
# ============================================================================
# Main
# ============================================================================


def main():
    parser = argparse.ArgumentParser(
        description="Recurrent KDA Benchmark",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python benchmarks/bench_recurrent_kda.py --batch-size 1 4 16 64 128 256
  python benchmarks/bench_recurrent_kda.py --head-size 64 --batch-size 1 32 128
  python benchmarks/bench_recurrent_kda.py --seq-len 1 2 3 4 --batch-size 1 32
""",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        nargs="+",
        default=[1, 4, 16, 64, 128, 256],
        help="Batch sizes to benchmark",
    )
    parser.add_argument("--num-q-heads", type=int, default=16)
    parser.add_argument("--num-k-heads", type=int, default=16)
    parser.add_argument("--num-v-heads", type=int, default=32)
    parser.add_argument("--head-size", type=int, default=128, choices=[64, 128])
    parser.add_argument(
        "--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16"
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        nargs="+",
        default=[1],
        help="Sequence length (T=1 only)",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=10,
        help="Number of warmup iterations",
    )
    parser.add_argument(
        "--iters",
        type=int,
        default=100,
        help="Number of benchmark iterations",
    )
    args = parser.parse_args()

    # Resolve dtype
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    dtype = dtype_map[args.dtype]

    run_recurrent_kda_benchmark(args, dtype)


if __name__ == "__main__":
    main()
```
`flashinfer/kda_kernels/__init__.py` (new file, +38 lines; path inferred from the `flashinfer.kda_kernels` import in the benchmark):
| """ | ||
| Copyright (c) 2025 by FlashInfer team. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| """ | ||
| KDA (Key-Driven Attention) Kernels - CuTe DSL Implementations | ||
| ============================================================== | ||
|
|
||
| Per-K-dimension gating variant of GDN. Gate g[B,T,HV,K] applied per-lane | ||
| instead of GDN's scalar broadcast. | ||
|
|
||
| Exported: | ||
| - recurrent_kda: Recurrent KDA decode kernel (T=1) | ||
| """ | ||
|
|
||
| try: | ||
| from .recurrent_kda import recurrent_kda | ||
|
|
||
| _has_cute_dsl = True | ||
| except ImportError: | ||
| _has_cute_dsl = False | ||
| recurrent_kda = None # type: ignore | ||
|
|
||
| __all__ = [ | ||
| "recurrent_kda", | ||
| ] |