perf: improve gdn decode cute-dsl kernels #2405
Conversation
**Remove unnecessary `state.copy_()` when state is already contiguous**
- For contiguous state, `h0_source` shares memory with `state`.
- The kernel updates `state` in place, so the `copy_` is redundant.

**Cache `h0_indices` and `cu_seqlens` tensors**
- These tensors have fixed values based on batch size.
- Reuse them from a cache instead of creating new tensors on each call.

**Use `bench_gpu_time` with CUPTI for the GDN decode benchmark**

Replace `torch.profiler` with `flashinfer.testing.bench_gpu_time`:
- More accurate kernel timing via CUPTI hardware profiling.
- Simpler code without trace-file parsing.
- Consistent with other FlashInfer benchmarks.

**Enable fastmath for exp/log/sqrt/rsqrt in GDN decode kernels**

Use `fastmath=True` for all `cute.exp`, `cute.log`, `cute.sqrt`, and `cute.rsqrt` calls to enable faster approximate math intrinsics. A ~1-2% improvement was observed at small batch sizes.

**Use `rsqrt` instead of `sqrt` + division in the GDN decode L2 norm**

Replace `norm = sqrt(sum); x = x / norm` with `inv_norm = rsqrt(sum); x = x * inv_norm` to eliminate a division instruction. The MTP kernel already used this pattern.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Resolved conflicts by keeping the GDN decode optimizations:
- `fastmath=True` for exp/log/rsqrt operations
- `rsqrt` instead of `sqrt` + division for the L2 norm
- Cached `h0_indices` and `cu_seqlens` to avoid repeated allocation
- `state.copy_` only runs when necessary
- `bench_gpu_time` (CUPTI) for benchmarks instead of `torch.profiler`

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
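The `rsqrt` rewrite described in these commits can be illustrated outside the CuTe DSL. Below is a hedged plain-Python sketch (not the kernel code; function names are illustrative only) showing that multiplying by the reciprocal square root matches dividing by the norm, while replacing one division per element with a multiply:

```python
import math

def l2_normalize_div(x, eps=1e-8):
    # Baseline form: one sqrt, then one division per element.
    norm = math.sqrt(sum(v * v for v in x) + eps)
    return [v / norm for v in x]

def l2_normalize_rsqrt(x, eps=1e-8):
    # Optimized form: one reciprocal sqrt, then only multiplies.
    inv_norm = 1.0 / math.sqrt(sum(v * v for v in x) + eps)
    return [v * inv_norm for v in x]

q = [1.0, 2.0, 2.0]
a = l2_normalize_div(q)
b = l2_normalize_rsqrt(q)
# Both forms agree to floating-point precision.
assert all(abs(u - w) < 1e-9 for u, w in zip(a, b))
```

On GPUs the reciprocal-square-root intrinsic is cheap, so trading the per-element division for a multiply is the standard formulation of this normalization.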
📝 Walkthrough

Added Triton-based GDN decode and MTP kernels and benchmarking flows, replaced trace timing with CUPTI-backed GPU timing, added comparison and correctness-verification paths between FlashInfer and Triton, and made MTP tile/vec sizes dynamic with per-config caching and adjusted state-layout handling.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant BenchHarness
    participant FlashInfer
    participant TritonKernel
    participant GPU
    User->>BenchHarness: start benchmark (mode: compare / verify)
    BenchHarness->>FlashInfer: run GDN decode / MTP (FlashInfer path)
    FlashInfer->>GPU: launch FlashInfer kernel
    GPU-->>FlashInfer: results + CUPTI timings
    BenchHarness->>TritonKernel: run Triton GDN decode / MTP
    TritonKernel->>GPU: launch Triton kernel
    GPU-->>TritonKernel: results + CUPTI timings
    BenchHarness->>BenchHarness: verify outputs and states
    BenchHarness-->>User: report timings and verification outcome
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed.
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request focuses on enhancing the performance and benchmarking capabilities of the GDN decode kernels. It introduces a comprehensive benchmarking framework that leverages CUPTI for precise kernel timing and enables direct performance comparisons against external Triton-based implementations. Concurrently, the core GDN CuTe DSL kernels themselves have been optimized.
Code Review
This pull request significantly improves the performance and benchmarking capabilities for the GDN decode kernels. The optimizations in the CuTe kernels, such as enabling fastmath and replacing division with multiplication via rsqrt, are well-implemented and should provide a good performance boost. The benchmark script has been substantially enhanced with the addition of a Triton kernel for baseline comparison, correctness verification, and a more accurate timing mechanism using bench_gpu_time. The code is well-structured and the new benchmarking features are very valuable. I have a couple of minor suggestions to further optimize the new Triton kernels for consistency with the performance goals of this PR.
```python
q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8)
k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8)
b_q = b_q / q_norm
b_k = b_k / k_norm
```

For better performance, it's recommended to use `tl.rsqrt` and multiplication for normalization instead of `tl.sqrt` and division. This is consistent with the optimizations applied to the CuTe kernels in this PR and is a standard practice for performance-critical GPU code.

```diff
-q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8)
-k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8)
-b_q = b_q / q_norm
-b_k = b_k / k_norm
+q_inv_norm = tl.rsqrt(tl.sum(b_q * b_q) + 1e-8)
+k_inv_norm = tl.rsqrt(tl.sum(b_k * b_k) + 1e-8)
+b_q = b_q * q_inv_norm
+b_k = b_k * k_inv_norm
```
```python
q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8)
k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8)
b_q = b_q / q_norm
b_k = b_k / k_norm
```

Similar to the other Triton kernel, using `tl.rsqrt` and multiplication for normalization will be more performant than `tl.sqrt` and division. This change would align the Triton baseline with the optimization principles used elsewhere in this PR.

```diff
-q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8)
-k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8)
-b_q = b_q / q_norm
-b_k = b_k / k_norm
+q_inv_norm = tl.rsqrt(tl.sum(b_q * b_q) + 1e-8)
+k_inv_norm = tl.rsqrt(tl.sum(b_k * b_k) + 1e-8)
+b_q = b_q * q_inv_norm
+b_k = b_k * k_inv_norm
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/bench_gdn_decode.py`:
- Around line 1926-1929: The current try/except around the batch verification
(the block that sets status = "PASS" if passed else "FAIL" and prints
Batch={batch_size}) should not catch broad Exception; replace the generic
"except Exception as e" with explicit handlers for the expected verification
errors (e.g., except AssertionError as e, except ValueError as e, except
RuntimeError as e) and handle each by printing Batch={batch_size}: ERROR -
<type>, and keep an optional final except: raise to avoid swallowing unknown
errors; reference the variables/status logic in this try/except block (status,
passed, batch_size) when updating the handlers.
🧹 Nitpick comments (4)

flashinfer/gdn_decode.py (2)

**945-950: Consider adding the device to the cache key for multi-GPU scenarios.**

The device check handles device mismatches by creating new tensors, but the cache key (`cache_key` at line 942) doesn't include the device. If the same configuration is used across multiple GPUs, the cached tensors will be repeatedly recreated. Consider either:
- adding `q.device` to the cache key, or
- using a nested dict keyed by device within the cache.

♻️ Suggested improvement

```python
# Option 1: Include device in cache key
cache_key = (B, T, H, HV, K, V, q.dtype, scale, use_qk_l2norm, q.device.index)
```

Or store per-device tensors in the cache:

```python
device_key = q.device.index
if "h0_indices" not in cache:
    cache["h0_indices"] = {}
    cache["cu_seqlens"] = {}
if device_key not in cache["h0_indices"]:
    cache["h0_indices"][device_key] = torch.zeros(B, dtype=torch.int32, device=q.device)
    cache["cu_seqlens"][device_key] = torch.zeros(B + 1, dtype=torch.int32, device=q.device)
h0_indices = cache["h0_indices"][device_key]
cu_seqlens = cache["cu_seqlens"][device_key]
```

**2419-2423: Potential issue: the contiguity check may not cover all copy scenarios.**

The condition checks `initial_state.is_contiguous()`, but `h0_source` is derived via `.to(torch.float32).reshape(...)`. While `.to()` returns `self` when the dtype matches, this relies on implementation details. Consider using the identity-check pattern (like the nontranspose version at line 1838) for consistency:

♻️ Suggested fix

```diff
-    if not disable_state_update and not initial_state.is_contiguous():
+    # Create reshaped view/copy of initial_state
+    h0_flat = initial_state.reshape(pool_size * HV, V, K)
+    h0_source = h0_flat.contiguous() if not h0_flat.is_contiguous() else h0_flat
+
+    # ... later after kernel execution ...
+    if not disable_state_update and h0_source.data_ptr() != initial_state.data_ptr():
+        initial_state.copy_(h0_source.reshape(pool_size, HV, V, K))
```

Alternatively, verify that the current logic works correctly in all cases.
benchmarks/bench_gdn_decode.py (2)

**219-227: Optional: Remove the unused `B` parameter from the Triton kernels.**

The `B` parameter is declared as `tl.constexpr` but is not used in the kernel body. The batch index is derived from `i_bh // HV`, and bounds checking uses `K_DIM` and `V_DIM`.

♻️ Suggested fix

Remove `B: tl.constexpr,` from the kernel signature and the corresponding argument at call sites. This applies to all three Triton kernels: `fused_sigmoid_gating_delta_rule_kernel`, `fused_sigmoid_gating_delta_rule_mtp_kernel`, and `fused_sigmoid_gating_delta_rule_kernel_pretranspose`.

**1346-1355: Consider renaming `speedup` for clarity.**

The current calculation `triton_median_us / flashinfer_median_us` means "how many times slower Triton is" rather than "how much faster FlashInfer is". While the comment at line 2049 clarifies this, consider renaming to `flashinfer_speedup` or inverting the ratio for a more intuitive interpretation.

♻️ Alternative naming

```python
# Option 1: Rename for clarity
flashinfer_speedup = triton_median_us / flashinfer_median_us  # >1 means FI faster

# Option 2: Use a ratio that matches conventional speedup semantics
speedup = flashinfer_median_us / triton_median_us  # >1 means Triton faster
# Then update the print statement: "Speedup > 1.0 means Triton is faster"
```
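To make the two ratio conventions concrete, here is a small, hedged plain-Python illustration (the variable names mirror the suggestion above; the sample timings are taken from the B=4 pretranspose row of the PR's benchmark table):

```python
flashinfer_median_us = 5.41
triton_median_us = 7.58

# Convention used in the script: > 1.0 means FlashInfer is faster.
flashinfer_speedup = triton_median_us / flashinfer_median_us

# Inverted convention: > 1.0 would mean Triton is faster.
triton_speedup = flashinfer_median_us / triton_median_us

assert flashinfer_speedup > 1.0  # FlashInfer wins on this sample
# The two conventions are reciprocals of each other.
assert abs(flashinfer_speedup * triton_speedup - 1.0) < 1e-12
```

Whichever direction is chosen, stating it once in the variable name avoids relying on a distant comment to interpret the printed table.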
```python
            status = "PASS" if passed else "FAIL"
            print(f"Batch={batch_size}: {status}")
        except Exception as e:
            print(f"Batch={batch_size}: ERROR - {type(e).__name__}")
```

**Catch specific exceptions instead of broad `Exception`.**

Catching `Exception` can mask unexpected errors. Consider catching the specific exceptions that verification can raise:

🔧 Suggested fix

```diff
-        except Exception as e:
-            print(f"Batch={batch_size}: ERROR - {type(e).__name__}")
+        except (RuntimeError, torch.cuda.CudaError) as e:
+            print(f"Batch={batch_size}: ERROR - {type(e).__name__}: {e}")
```

🧰 Tools
🪛 Ruff (0.14.14)
1928-1928: Do not catch blind exception: `Exception` (BLE001)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)

flashinfer/gdn_decode.py (1)

**238-244: Fix v-indexing: the `sV` preload only covers the first 128 elements.**

`sV` is filled with indices `0..127` regardless of `start_v_tiles` or `V`. For `V > 128`, or for small-batch blocks with `batch_inner > 0`, `v_tiles * TILE_V + ...` will read uninitialized values. Load `v` for the current tile, or read `v` directly in the loop.

🐛 Proposed fix (read v directly per tile)

```diff
-                v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk
+                v_idx = v_tiles * TILE_V + row + row_offset
+                v_new = cutlass.Float32(v[i_n, i_t, i_hv, v_idx]) - sum_hk
```

Also applies to: 346-347, 474-479, 581-582
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 105-114: The get_vec_size_mtp function currently chooses vec_size
solely from batch_size which can yield threads_per_group = K / vec_size > 32
when K > 128 and batch_size <= 4, causing groups_per_warp to become zero and
divide-by-zero downstream; update get_vec_size_mtp to consider K as well (either
accept K as an additional parameter or validate K before returning vec_size) and
select vec_size=8 when K/4 > 32 (or assert/raise a clear error if K is
incompatible), and apply the same guard/selection logic to the other occurrences
noted (around lines referenced: the other get_vec_size_mtp usages at 1917-1923,
2060-2062, 2333-2336) so threads_per_group never exceeds 32.
🧹 Nitpick comments (1)

flashinfer/gdn_decode.py (1)

**2241-2257: Silence the unused-arg lint for cache-only parameters.**

Ruff flags `tile_v`/`vec_size` as unused. Consider explicitly marking them as cache-key-only to avoid ARG001 noise.

♻️ Proposed fix

```diff
 def _get_compiled_mtp_kernel(
     B: int,
     T: int,
     H: int,
     @@
     tile_v: int,  # TILE_V - configurable for batch size
     vec_size: int,  # 4 for full warp, 8 for half-warp
 ):
     """Cache compiled MTP kernel for given configuration."""
+    _ = (tile_v, vec_size)  # used by functools.cache key
     return {}
```
flashinfer/gdn_decode.py (outdated)

```python
def get_vec_size_mtp(batch_size: int) -> int:
    """Select vec_size for MTP kernel based on batch size.

    B <= 4: vec_size=4 (full warp reduction, 5 shuffles) - better for small batch
    B > 4: vec_size=8 (half-warp reduction, 4 shuffles) - better for large batch
    """
    if batch_size <= 4:
        return 4  # Full warp: 32 threads * 4 elements = 128
    else:
        return 8  # Half-warp: 16 threads * 8 elements = 128
```

**Guard MTP `vec_size` against K > 128 in the small-batch path.**

With `B <= 4`, `vec_size=4` ⇒ `threads_per_group = K/4`. For `K > 128`, this exceeds 32, making `groups_per_warp = 0` and `rows_per_group = tile_v // 0` invalid. Either assert K compatibility for this path, or choose `vec_size` based on K as well.

🐛 Proposed fix (explicit guard)

```diff
-    vec_size = get_vec_size_mtp(B)
+    vec_size = get_vec_size_mtp(B)
+    assert K // vec_size <= 32, (
+        "K too large for selected vec_size; adjust vec_size selection or constrain K."
+    )
```

Also applies to: 1917-1923, 2060-2062, 2333-2336
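As a hedged sketch of the K-aware selection the review asks for (pure Python, mirroring the quoted helper; the exact signature and placement in `gdn_decode.py` are assumptions, not the merged code):

```python
def get_vec_size_mtp(batch_size: int, k_dim: int) -> int:
    """Pick vec_size so that threads_per_group = k_dim // vec_size stays <= 32.

    Small batches prefer vec_size=4 (full-warp reduction), but only when
    k_dim // 4 still fits in one warp; otherwise fall back to vec_size=8.
    """
    if batch_size <= 4 and k_dim // 4 <= 32:
        return 4  # full warp: 32 threads * 4 elements covers k_dim <= 128
    if k_dim // 8 <= 32:
        return 8  # half warp: 16 threads * 8 elements covers k_dim <= 256
    raise ValueError(f"K={k_dim} too large for warp-based thread grouping")

# threads_per_group never exceeds a warp for the supported K values:
for b in (1, 4, 8, 64):
    for k in (128, 256):
        vs = get_vec_size_mtp(b, k)
        assert k // vs <= 32
```

Raising early keeps the failure at the Python layer with a clear message, rather than letting `groups_per_warp = 0` surface as a cryptic divide-by-zero inside the kernel.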
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/bench_gdn_decode.py`:
- Around line 2463-2472: The code ignores --compare for non‑MTP runs because the
version branch always calls run_all_layouts_benchmark; move or add the compare
check so args.compare is evaluated first (or added into the non‑MTP branch).
Specifically, before calling run_all_layouts_benchmark, check if args.compare
and call run_comparison_benchmark(args, dtype, use_qk_l2norm); otherwise
preserve the existing behavior (for MTP: run_comparison_benchmark vs
run_flashinfer_only_benchmark; for non‑MTP: run_all_layouts_benchmark only when
not comparing). Update the logic around the run_all_layouts_benchmark,
run_comparison_benchmark, and run_flashinfer_only_benchmark calls accordingly so
--compare affects both MTP and non‑MTP paths.
🧹 Nitpick comments (1)

benchmarks/bench_gdn_decode.py (1)

**2062-2069: The summary speedup can mismatch batches when some results are missing.**

`fi_pre_times` and `tr_pre_times` are filtered independently, so the zipped list can pair different batches if any entry is missing on one side.

♻️ Suggested adjustment

```diff
-    fi_pre_times = [r["fi_pretrans_us"] for r in all_results if r.get("fi_pretrans_us")]
-    tr_pre_times = [r["tr_pretrans_us"] for r in all_results if r.get("tr_pretrans_us")]
-
-    if fi_pre_times and tr_pre_times:
-        speedups = [tr / fi for fi, tr in zip(fi_pre_times, tr_pre_times, strict=False)]
+    speedups = [
+        r["tr_pretrans_us"] / r["fi_pretrans_us"]
+        for r in all_results
+        if r.get("fi_pretrans_us") and r.get("tr_pretrans_us")
+    ]
+
+    if speedups:
         print(
             f"FlashInfer vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x"
         )
```
```python
    if args.version == "mtp":
        # MTP mode: use comparison or flashinfer-only
        if args.compare:
            run_comparison_benchmark(args, dtype, use_qk_l2norm)
        else:
            run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
    else:
        # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose)
        run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
```
--compare is ignored for non‑MTP paths.
For non‑MTP runs, main() always calls run_all_layouts_benchmark, so --compare (and single‑layout intent in the usage text) never takes effect. This is a user‑visible behavior bug.
🛠️ Proposed fix

```diff
     if args.version == "mtp":
         # MTP mode: use comparison or flashinfer-only
         if args.compare:
             run_comparison_benchmark(args, dtype, use_qk_l2norm)
         else:
             run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm)
     else:
-        # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose)
-        run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
+        # Non-MTP
+        if args.compare and args.version != "all":
+            run_comparison_benchmark(args, dtype, use_qk_l2norm)
+        else:
+            run_all_layouts_benchmark(args, dtype, use_qk_l2norm)
```
The precision is fine compared to the Triton reference. Performance improves by approximately 20%~40% over the Triton kernel at large batch sizes. However, for small batch sizes, the Triton reference should use BV=8.
- Change from an interleaved to a contiguous memory access pattern
- Use `cute.local_tile` + `cute.autovec_copy` for vectorized memory operations instead of scalar for-loops
- Set vec_size=8 (half-warp, 8 groups) for all batch sizes
- Tune tile_v: 8/16/32/64 based on batch size thresholds

This achieves a 1.06x-1.45x speedup over the Triton baseline (avg 1.17x).

- Use 3D `local_tile` directly instead of slice + 1D `local_tile`
- Cleaner code without intermediate tensor creation

- Replace the scalar loop with `autovec_copy` for coalesced vectorized loads
- Load into BF16 registers first, then convert to FP32
- Improves the average speedup from 1.18x to 1.22x over Triton

- Change q, k, v loading from an interleaved to a contiguous pattern
- Use `autovec_copy` for vectorized loads (BF16 -> FP32 conversion)
- Use `local_tile` + `autovec_copy` for h read/write in the mainloop
- Applies to both small_batch and big_batch pretranspose kernels

Achieves a 1.4x-1.6x speedup over the Triton pretranspose baseline.

- Always use vec_size=4 (32 threads/group = full warp) instead of vec_size=8
- Full-warp shuffle is more efficient than half-warp shuffle
- Tune tile_v per batch size via grid search: B≤2: tile_v=4, B≤4: tile_v=8, B≤8: tile_v=16, B≤16: tile_v=32, B>16: tile_v=64
- Remove dead code for vec_size=8 shuffle branches
- Achieves >= 1.04x speedup vs Triton across all batch sizes (avg 1.14x)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
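The grid-searched tile_v thresholds in the last commit map batch size to a tile width. A hedged plain-Python sketch of that mapping (the function name is hypothetical, not the actual helper in `gdn_decode.py`):

```python
def select_tile_v(batch_size: int) -> int:
    """Map batch size to the tile_v found by grid search:
    B<=2 -> 4, B<=4 -> 8, B<=8 -> 16, B<=16 -> 32, B>16 -> 64.

    Intuition: smaller batches use narrower tiles so more CTAs stay busy,
    while larger batches amortize per-tile overhead with wider tiles.
    """
    for threshold, tile_v in ((2, 4), (4, 8), (8, 16), (16, 32)):
        if batch_size <= threshold:
            return tile_v
    return 64

assert [select_tile_v(b) for b in (1, 2, 4, 8, 16, 32, 256)] == [4, 4, 8, 16, 32, 64, 64]
```

Because the thresholds came from a grid search on one GPU (B200 in this PR), the exact cutoffs would likely need re-tuning on other hardware.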
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)

flashinfer/gdn_decode.py (2)

**728-728: Dead code: an expression is computed but its result discarded.**

This expression computes a value (likely a data size in MB) but doesn't assign it anywhere. Remove it, or assign it to a variable if it is needed for debugging.

🧹 Proposed fix

```diff
     num_v_tiles = cute.ceil_div(v_dim, TILE_V)
-    v_dim * k_dim * batch_size * 4 / 1024 / 1024
     vec_size = (
```

**831-831: Dead code: another unused expression.**

The same issue as line 728. This and the commented print statements below should be cleaned up.

🧹 Proposed fix

```diff
     num_v_tiles = cute.ceil_div(v_dim, TILE_V)
-    v_dim * k_dim * batch_size * 4 / 1024 / 1024
     vec_size = (
```
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 1603-1606: The stray expression h0_indices.layout.shape[0] is
evaluated but not used; either remove this no-op or assign it to a meaningful
variable (e.g., hv_dim or use it to compute batch_size) so the value is
consumed. Locate the block where h0_source.layout.shape unpacks into
batch_hv_dim, k_dim, v_dim and the subsequent h0_indices.layout.shape[0]
expression, and either delete that expression or replace it with an assignment
that is actually used by downstream logic (same fix also apply to the analogous
occurrence around the h0_* code at lines corresponding to 1685-1687). Ensure
variable names (h0_indices, h0_source, batch_hv_dim, batch_size) are consistent
after the change.
- Around line 2416-2420: Add an upper-bound validation for K to fail fast when
the MTP kernel's warp-based grouping is violated: in the same validation block
that contains the asserts for K and V (referencing symbols K, V, and tile_v in
gdn_decode.py) add an assertion that K <= 128 with a clear error message like "K
must be at most 128, got K={K}" so the code raises a descriptive exception
instead of producing cryptic kernel failures.
🧹 Nitpick comments (1)

flashinfer/gdn_decode.py (1)

**1997-2000: Thread grouping assumes K ≤ 128; document or assert this constraint.**

The calculation `threads_per_group = K // vec_size` assumes the result fits within a warp (≤32 threads). With `vec_size=4` (always returned by `get_vec_size_mtp`), this requires K ≤ 128. Consider adding a comment or an assert to make this constraint explicit at the kernel level.
```python
# h0_source is flattened to [B*HV, K, V] to ensure proper alignment for SIMT async copy
batch_hv_dim, k_dim, v_dim = h0_source.layout.shape
h0_indices.layout.shape[0]
batch_size = batch_hv_dim  # batch_hv_dim = B * HV
```

**Unused expression: `h0_indices.layout.shape[0]` is evaluated but not assigned.**

This statement has no effect. Either remove it, or assign it to a variable if the value is needed.

🧹 Proposed fix

```diff
 # h0_source is flattened to [B*HV, K, V] to ensure proper alignment for SIMT async copy
-batch_hv_dim, k_dim, v_dim = h0_source.layout.shape
-h0_indices.layout.shape[0]
-batch_size = batch_hv_dim  # batch_hv_dim = B * HV
+batch_hv_dim, _k_dim, v_dim = h0_source.layout.shape
+batch_size = batch_hv_dim  # batch_hv_dim = B * HV
```

Also applies to lines 1685-1687.

🧰 Tools
🪛 Ruff (0.14.14)
1604-1604: Unpacked variable `k_dim` is never used. Prefix it with an underscore or any other dummy variable pattern. (RUF059)
```diff
 assert K >= 128, f"K must be at least 128, got K={K}"
 assert V >= 128, f"V must be at least 128, got V={V}"
-assert V % TILE_V_MTP == 0, (
-    f"V must be divisible by {TILE_V_MTP} to prevent out-of-bounds access, got V={V}"
+assert V % tile_v == 0, (
+    f"V must be divisible by {tile_v} to prevent out-of-bounds access, got V={V}"
 )
```

**Missing K upper-bound validation for the MTP kernel.**

The validation checks `K >= 128`, but the kernel requires `K <= 128` due to the warp-based thread grouping. Add an upper-bound check to fail fast with a clear error message instead of cryptic kernel failures.

🐛 Proposed fix

```diff
 # Validate K and V constraints
 assert K >= 128, f"K must be at least 128, got K={K}"
+assert K == 128, f"MTP kernel currently requires K=128 (warp thread grouping constraint), got K={K}"
 assert V >= 128, f"V must be at least 128, got V={V}"
 assert V % tile_v == 0, (
     f"V must be divisible by {tile_v} to prevent out-of-bounds access, got V={V}"
 )
```
@xutizhou the performance gap for B=4 should be fixed in the most recent commit.
/bot run |
[FAILED] Pipeline #42706808: 11/20 passed |
[like] Xuting ZHOU reacted to your message.
<!-- .github/pull_request_template.md --> ## 📌 Description Follow up of flashinfer-ai#2370 , this PR improves the benchmark scripts and add comparison with baselines: * benchmark using cupti with l2 flush * compare with sglang's `fused_sigmoid_gating_delta_rule_update` function (with tile size optimization mentioned by @ vadiklyutiy). this PR also implements some optimizations on the original gdn kernel: * use fastmath as much as we can * change "/" to multiply * Use `cutlass.range_constexpr` and `cutlass.const_expr` whenever possible * fuse scale and inv_norm_q * For mtp, store state in registers directly, without load/write to shared memory, and remove cpasync * Vectorized memory access. ## Performance on B200 Non MTP setting ``` > python benchmarks/bench_gdn_decode.py --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify === Correctness Verification === Batch=8: Pretranspose: PASS Nontranspose: PASS Batch=16: Pretranspose: PASS Nontranspose: PASS Batch=32: Pretranspose: PASS Nontranspose: PASS Batch=64: Pretranspose: PASS Nontranspose: PASS ======================================================================================================================== GDN Decode Benchmark: FlashInfer vs Triton, Pretranspose vs Nontranspose Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON ======================================================================================================================== batch | FI-PreTr FI-NonTr | TR-PreTr TR-NonTr | FI/TR-Pre FI/TR-Non | Pre/Non-FI Pre/Non-TR | (us) (us) | (us) (us) | speedup speedup | speedup speedup ------------------------------------------------------------------------------------------------------------------------ 1 | 3.74 5.06 | 5.95 4.35 | 1.59x 0.86x | 1.35x 0.73x 2 | 4.29 5.89 | 6.37 5.02 | 1.49x 0.85x | 1.37x 0.79x 4 | 5.41 7.78 | 7.58 6.66 | 1.40x 0.86x | 1.44x 0.88x 8 | 7.65 12.03 | 9.95 10.21 | 1.30x 0.85x | 1.57x 1.03x 16 | 12.61 19.30 | 16.83 15.81 | 1.34x 
 0.82x |  1.53x  0.94x
   32 |     22.91     32.86 |     31.55     27.84 |  1.38x  0.85x |  1.43x  0.88x
   64 |     52.74     58.61 |     58.91     53.02 |  1.12x  0.90x |  1.11x  0.90x
  128 |     92.93    107.98 |    114.45    106.78 |  1.23x  0.99x |  1.16x  0.93x
  256 |    170.77    209.04 |    225.71    216.41 |  1.32x  1.04x |  1.22x  0.96x
------------------------------------------------------------------------------------------------------------------------
Legend:
  FI-PreTr = FlashInfer Pretranspose [B, HV, V, K]
  FI-NonTr = FlashInfer Nontranspose [B, HV, K, V]
  TR-PreTr = Triton Pretranspose [B, HV, V, K]
  TR-NonTr = Triton Nontranspose [B, HV, K, V]
  FI/TR speedup > 1.0 means FlashInfer is faster than Triton
  Pre/Non speedup > 1.0 means Pretranspose is faster than Nontranspose

FlashInfer vs Triton (Pretranspose) - Average speedup: 1.35x
```

MTP Setting (pretranspose only)

```
> python benchmarks/bench_gdn_decode.py --version mtp --batch-size 1 2 4 8 16 32 64 128 256 --compare --verify

=== Correctness Verification (MTP) ===
Batch=8: PASS
Batch=16: PASS
Batch=32: PASS
Batch=64: PASS

GDN MTP Comparison: FlashInfer (CuTe DSL) vs Triton
Config: q_heads=16, k_heads=16, v_heads=32, head_size=128, dtype=bfloat16, qk_l2norm=ON, cache_intermediate=OFF
--------------------------------------------------------------------------------------------------------------
 batch  seq_len  FlashInfer(us)  Triton(us)  FI TFLOPS  TR TFLOPS  Speedup
--------------------------------------------------------------------------------------------------------------
     1        2            9.22       10.05       0.68       0.63    1.09x
     1        4           11.20       14.43       1.12       0.87    1.29x
     1        8           15.81       22.08       1.59       1.14    1.40x
     2        2           10.11       10.69       1.24       1.18    1.06x
     2        4           12.58       15.10       2.00       1.67    1.20x
     2        8           18.82       23.63       2.67       2.13    1.26x
     4        2           11.39       11.94       2.21       2.11    1.05x
     4        4           15.23       16.54       3.30       3.04    1.09x
     4        8           23.62       25.50       4.26       3.95    1.08x
     8        2           14.69       17.23       3.43       2.92    1.17x
     8        4           21.20       25.01       4.75       4.03    1.18x
     8        8           34.69       40.86       5.80       4.93    1.18x
    16        2           21.47       24.22       4.69       4.16    1.13x
    16        4           32.54       36.98       6.19       5.44    1.14x
    16        8           56.24       61.76       7.16       6.52    1.10x
    32        2           33.50       37.68       6.01       5.34    1.12x
    32        4           54.66       60.26       7.37       6.68    1.10x
    32        8           97.98      104.35       8.22       7.72    1.06x
    64        2           59.82       65.38       6.73       6.16    1.09x
    64        4          102.05      108.83       7.89       7.40    1.07x
    64        8          188.17      196.45       8.56       8.20    1.04x
   128        2          107.44      121.41       7.50       6.63    1.13x
   128        4          192.01      209.90       8.39       7.67    1.09x
   128        8          366.81      389.12       8.78       8.28    1.06x
   256        2          199.14      236.19       8.09       6.82    1.19x
   256        4          363.36      422.61       8.87       7.62    1.16x
   256        8          708.22      787.05       9.10       8.19    1.11x
--------------------------------------------------------------------------------------------------------------
Speedup > 1.0 means FlashInfer is faster

Summary:
  Average speedup: 1.13x
  Min speedup: 1.04x (batch=64, T=8)
  Max speedup: 1.40x (batch=1, T=8)
```

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added Triton-based benchmarks and end-to-end comparison/verify modes across multiple memory layouts (including MTP); new verification flows to compare implementations.
* **Performance Improvements**
  * Batch-size-aware kernel selection, configurable tile/vec sizing, fast-math paths, reduced redundant copies, and CUPTI-backed GPU timing for more accurate benchmarks.
* **Behavior & Compatibility**
  * Improved layout handling, expanded CLI presets/modes, clearer error messages and guards when Triton is unavailable; default benchmark mode updated.
* **Documentation**
  * Updated usage examples and CLI guidance.

<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: HongliMi <1667738261@qq.com>
Co-authored-by: Hongli Mi <hmi@nvidia.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
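The summary statistics in the comparison tables above (per-row speedup plus average/min/max) follow directly from the timing pairs; a small self-contained sketch of that arithmetic (the helper name is illustrative, the sample values are two rows lifted from the MTP table):

```python
def summarize_speedups(timings):
    """timings: list of (flashinfer_us, triton_us) pairs.
    Speedup > 1.0 means FlashInfer is faster."""
    speedups = [tr / fi for fi, tr in timings]
    return {
        "avg": sum(speedups) / len(speedups),
        "min": min(speedups),
        "max": max(speedups),
    }

# (batch=1, T=8) and (batch=64, T=8) from the MTP table above
stats = summarize_speedups([(15.81, 22.08), (188.17, 196.45)])
# max rounds to 1.40x, min to 1.04x, matching the reported extremes
```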
ameynaik-hub
left a comment
Also, cache intermediate states should be True since it is required for MTP.
```
@@ -425,52 +1213,14 @@ def bench_gdn_mtp(
    intermediate_states_buffer,
    disable_state_update=True,
```
Shouldn't this be `False` for benchmarking? Because we want the updated state to be an output in MTP.
```
@@ -425,52 +1213,14 @@ def bench_gdn_mtp(
    intermediate_states_buffer,
    disable_state_update=True,
```
@yzh119 I think this should be `False`; we want `h` updated as output.
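To make the disputed flag concrete, a toy stand-in (not the real FlashInfer kernel; the update rule, shapes, and helper name are made up) showing that with `disable_state_update=True` the recurrent state is never written back — which is why benchmarking MTP with it set would miss the state-update cost:

```python
def toy_mtp_step(q, state, disable_state_update):
    """Toy model of the MTP decode step: only the write-back
    behavior of disable_state_update is modeled; the update
    rule below is a placeholder, not the gated delta rule."""
    new_state = [0.9 * s + 1.0 for s in state]  # placeholder update
    out = sum(qi * si for qi, si in zip(q, state))
    if not disable_state_update:
        state[:] = new_state  # in-place write-back, visible to caller
    return out

state = [0.0, 0.0]
toy_mtp_step([1.0, 1.0], state, disable_state_update=True)
state_after_disabled = list(state)  # unchanged: update was discarded
toy_mtp_step([1.0, 1.0], state, disable_state_update=False)
state_after_enabled = list(state)   # updated state written back
```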
## 📌 Description

Follow-up of #2370. This PR improves the benchmark scripts and adds comparison with baselines:

- `fused_sigmoid_gating_delta_rule_update` function (with the tile size optimization mentioned by @vadiklyutiy)

This PR also implements some optimizations on the original GDN kernel:

- use `cutlass.range_constexpr` and `cutlass.const_expr` whenever possible

### Performance on B200

Results for the non-MTP setting and the MTP setting (pretranspose only) are shown in the tables above.
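One of the kernel changes from this PR's commit history (using rsqrt instead of sqrt followed by division in the L2 norm) is an algebraic rewrite; a minimal pure-Python sketch of the identity (the actual kernel uses the hardware `cute.rsqrt` intrinsic with `fastmath=True`, not `math.sqrt`, and the function names here are illustrative):

```python
import math

def l2_normalize_div(x):
    # Original pattern: norm = sqrt(sum(x^2)); x = x / norm
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm for v in x]

def l2_normalize_rsqrt(x):
    # Optimized pattern: inv_norm = rsqrt(sum(x^2)); x = x * inv_norm.
    # One reciprocal total instead of one division per element; on GPU
    # this lowers to a single rsqrt instruction.
    inv_norm = 1.0 / math.sqrt(sum(v * v for v in x))
    return [v * inv_norm for v in x]
```

Both forms normalize to unit length; the two kernels (decode and MTP) now share the multiply-by-inverse pattern.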