Skip to content

KV Split Oversubscription for Mixed Sequence Lengths#3379

Open
PatrykSaffer wants to merge 9 commits into
flashinfer-ai:mainfrom
PatrykSaffer:patryk/mixed-seq-len-multiwave
Open

KV Split Oversubscription for Mixed Sequence Lengths#3379
PatrykSaffer wants to merge 9 commits into
flashinfer-ai:mainfrom
PatrykSaffer:patryk/mixed-seq-len-multiwave

Conversation

@PatrykSaffer
Copy link
Copy Markdown

@PatrykSaffer PatrykSaffer commented May 21, 2026

📌 Description

When a batch contains mixed sequence lengths, the existing heuristic in computeCtaAndClusterConfig assigns numCtasPerSeqKv based on filling one wave of SM occupancy. This under-splits long sequences.

This PR allows SM oversubscription — splitting KV beyond one wave when max_seq_len demands it. Short sequences' extra CTAs early-exit. Three tunable constants control the behavior: kMinTokensPerCta, kMaxOccupancyWaves, and kMaxSplits.

The biggest gains are visible when the distribution is close to prod-like, for that I added distribution [4096, 8192, 16384, 32768, 65536, 131072] with weights [35, 20, 15, 12, 10, 8] and benchmarks based on that.

Benchmarks for DSv3 Running TP8

Screenshot 2026-05-20 at 13 18 56

constants tunning:

Screenshot 2026-05-20 at 13 19 21

Results from benchmarks/bench_trtllm_gen_mla.py

random context len:

batch_size q_len max_seq_len Baseline (ms) Fix (ms) Speedup
1 1 1024 0.0106 0.0106 1.00x
1 1 4096 0.0167 0.0167 1.00x
1 1 8192 0.0227 0.0227 1.00x
1 1 16384 0.0356 0.0356 1.00x
1 1 32768 0.0274 0.0274 1.00x
1 1 65536 0.0325 0.0325 1.00x
1 4 1024 0.0143 0.0143 1.00x
1 4 4096 0.0300 0.0300 1.00x
1 4 8192 0.0190 0.0190 1.00x
1 4 16384 0.0242 0.0241 1.00x
1 4 32768 0.0331 0.0331 1.00x
1 4 65536 0.0604 0.0712 0.85x
4 1 1024 0.0132 0.0132 1.00x
4 1 4096 0.0264 0.0264 1.00x
4 1 8192 0.0191 0.0191 1.00x
4 1 16384 0.0235 0.0235 1.00x
4 1 32768 0.0388 0.0389 1.00x
4 1 65536 0.0631 0.0618 1.02x
4 4 1024 0.0233 0.0233 1.00x
4 4 4096 0.0218 0.0218 1.00x
4 4 8192 0.0316 0.0316 1.00x
4 4 16384 0.0516 0.0530 0.97x
4 4 32768 0.0997 0.0968 1.03x
4 4 65536 0.1870 0.1839 1.02x
16 1 1024 0.0229 0.0229 1.00x
16 1 4096 0.0214 0.0214 1.00x
16 1 8192 0.0329 0.0329 1.00x
16 1 16384 0.0629 0.0645 0.98x
16 1 32768 0.1168 0.1266 0.92x
16 1 65536 0.2311 0.2251 1.03x
16 4 1024 0.0187 0.0187 1.00x
16 4 4096 0.0484 0.0484 1.00x
16 4 8192 0.0937 0.0937 1.00x
16 4 16384 0.1846 0.1649 1.12x
16 4 32768 0.3809 0.4039 0.94x
16 4 65536 0.7763 0.6665 1.16x
32 1 1024 0.0175 0.0175 1.00x
32 1 4096 0.0360 0.0360 1.00x
32 1 8192 0.0638 0.0638 1.00x
32 1 16384 0.1143 0.1039 1.10x
32 1 32768 0.2263 0.2194 1.03x
32 1 65536 0.4245 0.3999 1.06x
32 4 1024 0.0302 0.0302 1.00x
32 4 4096 0.0951 0.0951 1.00x
32 4 8192 0.1724 0.1724 1.00x
32 4 16384 0.3510 0.3442 1.02x
32 4 32768 0.7586 0.6754 1.12x
32 4 65536 1.4970 1.0700 1.40x
64 1 1024 0.0186 0.0186 1.00x
64 1 4096 0.0602 0.0602 1.00x
64 1 8192 0.1150 0.1150 1.00x
64 1 16384 0.2211 0.2016 1.10x
64 1 32768 0.4381 0.4312 1.02x
64 1 65536 0.7947 0.6026 1.32x
64 4 1024 0.0503 0.0503 1.00x
64 4 4096 0.1469 0.1469 1.00x
64 4 8192 0.3045 0.3045 1.00x
64 4 16384 0.6482 0.6104 1.06x
64 4 32768 1.3375 1.1416 1.17x
64 4 65536 2.2376 1.6997 1.32x
128 1 1024 0.0368 0.0368 1.00x
128 1 4096 0.1096 0.1096 1.00x
128 1 8192 0.2161 0.2161 1.00x
128 1 16384 0.4302 0.4162 1.03x
128 1 32768 0.8667 0.7828 1.11x
128 1 65536 1.5258 1.1449 1.33x
128 4 1024 0.0874 0.0874 1.00x
128 4 4096 0.2725 0.2725 1.00x
128 4 8192 0.5879 0.5879 1.00x
128 4 16384 1.2347 1.1466 1.08x
128 4 32768 2.5940 2.2304 1.16x
128 4 65536 4.1211 3.4265 1.20x
256 1 1024 0.0707 0.0707 1.00x
256 1 4096 0.1684 0.1684 1.00x
256 1 8192 0.3587 0.3587 1.00x
256 1 16384 0.6843 0.6994 0.98x
256 1 32768 1.4159 1.3482 1.05x
256 1 65536 2.6901 2.4722 1.09x
256 4 1024 0.1651 0.1651 1.00x
256 4 4096 0.5126 0.5126 1.00x
256 4 8192 1.1125 1.1125 1.00x
256 4 16384 2.2097 2.2193 1.00x
256 4 32768 4.7050 4.7077 1.00x
256 4 65536 8.9918 8.9955 1.00x
512 1 1024 0.1091 0.1091 1.00x
512 1 4096 0.3125 0.3125 1.00x
512 1 8192 0.6592 0.6592 1.00x
512 1 16384 1.2598 1.2885 0.98x
512 1 32768 2.5687 2.5332 1.01x
512 1 65536 4.9758 4.8236 1.03x
512 4 1024 0.3522 0.3522 1.00x
512 4 4096 1.0153 1.0153 1.00x
512 4 8192 2.1773 2.1773 1.00x
512 4 16384 4.3404 4.3469 1.00x
512 4 32768 9.1153 9.0990 1.00x
512 4 65536 18.0686 18.0088 1.00x
1024 1 1024 0.1921 0.1921 1.00x
1024 1 4096 0.6320 0.6320 1.00x
1024 1 8192 1.2372 1.2372 1.00x
1024 1 16384 2.4047 2.3840 1.01x
1024 1 32768 4.9105 4.9226 1.00x
1024 4 1024 0.7000 0.7000 1.00x
1024 4 4096 2.0706 2.0706 1.00x
1024 4 8192 4.3104 4.3104 1.00x
1024 4 16384 8.5939 8.5827 1.00x
1024 4 32768 18.0348 18.0152 1.00x

Context length distribution: [4096, 8192, 16384, 32768, 65536, 131072] with weights [35, 20, 15, 12, 10, 8].

batch_size q_len Baseline (ms) Fix (ms) Speedup
16 1 0.3737 0.1461 2.56x
16 4 1.3655 0.3790 3.60x
32 1 0.7365 0.3475 2.12x
32 4 2.2080 0.9370 2.36x
64 1 1.4719 0.5618 2.62x
64 4 2.9222 1.5342 1.90x
128 1 2.1495 1.0261 2.09x
128 4 4.0029 3.0739 1.30x
256 1 3.2009 2.1288 1.50x
256 4 7.8784 7.8717 1.00x

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Summary by CodeRabbit

  • Bug Fixes

    • Ensure at least one reduction task is launched to avoid zero-CTA cases.
  • Performance Improvements

    • Increased decode kernel workspace to handle larger workloads.
    • Improved CTA/occupancy heuristics and workspace sizing for better throughput on mixed-length batches.
    • Added environment tunables to control KV-split oversubscription behavior.
  • New Features

    • Benchmarks support mixed per-request sequence lengths, include a production-like length sampler, and run focused mixed-length workload scenarios.

Review Change Stack

root added 3 commits May 21, 2026 09:34
Signed-off-by: root <root@slurm-eus-04a-prod-b200-192-107.slurm-eus-04a-prod-compute.tenant-slurm.svc.cluster.local>
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 21, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Benchmarks accept per-request mixed sequence lengths and run mixed-length workloads; FMHA kernel CTA/cluster heuristics and multi-CTA workspace sizing were adapted to use CtaLaunchParams-derived dimensions; reduction path now ensures at least one CTA is launched.

Changes

Mixed-Sequence MLA Workload Support

Layer / File(s) Summary
Mixed-sequence benchmark framework
benchmarks/bench_trtllm_gen_mla.py
bench_trtllm_mla gains optional seq_lens_list to supply per-request sequence lengths; per-item blocks_per_seq and max_seq_len are derived from the list; decode workspace_buffer increased to 1024 MB; sample_prod_distribution helper added; main runs mixed-length workloads for batch sizes [16,32,64,128,256].
Adaptive CTA and workspace sizing for variable sequences
include/flashinfer/trtllm/fmha/fmhaKernels.cuh, include/flashinfer/trtllm/fmha/kernelParams.h, include/flashinfer/trtllm/common.h
TllmGenFmhaKernel::run and computeCtaAndClusterConfig use a baseCtas product and optionally oversubscribe numCtasPerSeqKv when mMaxSeqLenKv implies extra KV splits, bounded by SM-derived caps and fixed limits. KernelParams::setKernelParams signature now accepts CtaLaunchParams and computes partialStatsBufferSize from CTA grid dims and kernelMeta.mStepQ; params.mMaxNumCtasQ/mMaxNumCtasKv are set from ctaLaunchParams. New env helpers expose KV-oversub tuning parameters.
CTA launch configuration minimum bound
csrc/fmhaReduction.cu
runFmhaReduction clamps numCtasForReduction with std::max(1, std::min(maxNumCtasForReduction, numSlices)) to ensure at least one reduction CTA is launched.

Sequence Diagram

sequenceDiagram
  participant Main as __main__
  participant Sampler as sample_prod_distribution
  participant Bench as bench_trtllm_mla
  participant Kernel as TllmGenFmhaKernel
  Main->>Sampler: batch_size, seed
  Sampler-->>Main: seq_lens_list
  Main->>Bench: bench_trtllm_mla(seq_lens_list=sl, seq_len=max(sl), batch_size, q_len_per_request=1, page_size=32)
  Bench->>Kernel: launch FMHA kernels with ctaLaunchParams derived from seq_lens
  Kernel->>Kernel: KernelParams::setKernelParams(params, kernelMeta, ctaLaunchParams)
  Kernel->>Kernel: computeCtaAndClusterConfig -> baseCtas, numCtasPerSeqKv (maybe oversubscribe)
  Kernel-->>Bench: kernel launch / execution
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

run-ci, model: dsv3.2, op: attention

Suggested reviewers

  • sricketts
  • aleozlx
  • yzh119

Poem

🐇 I sampled hops of varied length and cheer,
Mixed sequences bustling, buffers growing near,
CTAs now promise at least one little crew,
Kernels size their work and split where needed too,
A rabbit nods — the mixed-length run is here 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly and concisely describes the main technical change: adding KV split oversubscription to handle mixed sequence lengths, which is the core feature introduced in this changeset.
Description check ✅ Passed The PR description is comprehensive, includes detailed motivation, benchmark results with concrete speedup numbers, tuning constants explanation, and confirms pre-commit checks and tests passed. However, it lacks a dedicated 'Related Issues' section as specified in the template.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces mixed sequence length benchmarks using production-like distributions and optimizes KV parallelism for long sequences in mixed-length batches by oversubscribing SMs. Key changes include refining the partial buffer size calculation and ensuring a minimum number of CTAs for reduction. Feedback highlights that the workspace buffer size may still be insufficient for large batches, recommends optimizing memory allocation by using actual sequence lengths instead of the maximum for block calculations, and identifies a potential under-calculation in the partial buffer size logic that could lead to out-of-bounds memory access.

Comment thread benchmarks/bench_trtllm_gen_mla.py Outdated
Comment thread benchmarks/bench_trtllm_gen_mla.py Outdated
Comment thread include/flashinfer/trtllm/fmha/kernelParams.h Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@benchmarks/bench_trtllm_gen_mla.py`:
- Around line 34-40: Check and validate seq_lens_list before using it: in the
branch that reads seq_lens_list (affecting variables seq_lens and max_seq_len),
ensure seq_lens_list is not None, not empty, has length == batch_size, and
contains integer values in the range [1, seq_len]; if validation fails either
raise a clear ValueError (including batch_size and seq_len) or coerce/fill
missing entries (e.g., pad/truncate to batch_size and clamp values to [1,
seq_len]) before calling max(seq_lens) and downstream indexing.

In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 832-836: partialStatsBufferSize is computed from
numCtasForAllHeads_ using kernelMeta.mStepQ but the kernel launch groups heads
differently (mGroupsHeadsQ / mNumHeadsQPerKv), so the buffer can be undersized
causing OOB for ptrPartialStats/ptrPartialO; update the calculation so
numCtasForAllHeads_ is derived from the same grouping used at launch (use
options.mGroupsHeadsQ or kernelMeta.mNumHeadsQPerKv / options.mNumHeadsQPerKv as
appropriate) instead of options.mNumHeadsQ / kernelMeta.mStepQ, then recompute
partialStatsBufferSize with that corrected numCtasForAllHeads_ to ensure full
allocation for all grouped head CTAs.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c33f1a3a-5d52-467a-948e-e11e63ec47c0

📥 Commits

Reviewing files that changed from the base of the PR and between 3a81c3e and d612247.

📒 Files selected for processing (4)
  • benchmarks/bench_trtllm_gen_mla.py
  • csrc/fmhaReduction.cu
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/kernelParams.h

Comment thread benchmarks/bench_trtllm_gen_mla.py
Comment thread include/flashinfer/trtllm/fmha/kernelParams.h Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

832-838: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add the divisibility check that was part of the prior suggested fix.

The numHeadsPerCta_ / numCtasForAllHeads_ derivation matches the previous proposal, but the divisibility guard was omitted. When kernelMeta.mGroupsHeadsQ is true and options.mNumHeadsQPerKv >= kernelMeta.mStepQ, the divisor on Line 834 becomes mStepQ, and options.mNumHeadsQ % mStepQ is not guaranteed to be zero. Silent truncation here under-allocates partialStatsBufferSize and can produce OOB writes/reads via ptrPartialStats / ptrPartialO.

🛡️ Proposed fix to add the divisibility guard
     int const numHeadsPerCta_ =
         kernelMeta.mGroupsHeadsQ ? std::min(options.mNumHeadsQPerKv, kernelMeta.mStepQ) : 1;
+    FLASHINFER_CHECK(options.mNumHeadsQ % numHeadsPerCta_ == 0,
+                     "numHeadsQ (%d) must be divisible by numHeadsPerCta (%d)",
+                     options.mNumHeadsQ, numHeadsPerCta_);
     int const numCtasForAllHeads_ = options.mNumHeadsQ / numHeadsPerCta_;
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 832 - 838, Ensure
the number of Q heads is divisible by heads-per-CTA before using it as a
divisor: when kernelMeta.mGroupsHeadsQ is true compute numHeadsPerCta_ as shown,
then assert or handle the case where options.mNumHeadsQ % numHeadsPerCta_ != 0
(e.g., return error, throw, or round up and document) before computing
numCtasForAllHeads_ and partialStatsBufferSize so you don't silently truncate
and under-allocate; update the code paths that use numCtasForAllHeads_,
partialStatsBufferSize, and pointers ptrPartialStats/ptrPartialO to rely only on
the validated/divisible value.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 832-838: Ensure the number of Q heads is divisible by
heads-per-CTA before using it as a divisor: when kernelMeta.mGroupsHeadsQ is
true compute numHeadsPerCta_ as shown, then assert or handle the case where
options.mNumHeadsQ % numHeadsPerCta_ != 0 (e.g., return error, throw, or round
up and document) before computing numCtasForAllHeads_ and partialStatsBufferSize
so you don't silently truncate and under-allocate; update the code paths that
use numCtasForAllHeads_, partialStatsBufferSize, and pointers
ptrPartialStats/ptrPartialO to rely only on the validated/divisible value.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f22251bc-3145-4157-a0a8-c2cc517ec89c

📥 Commits

Reviewing files that changed from the base of the PR and between d612247 and bb416d7.

📒 Files selected for processing (2)
  • benchmarks/bench_trtllm_gen_mla.py
  • include/flashinfer/trtllm/fmha/kernelParams.h

Comment thread include/flashinfer/trtllm/fmha/kernelParams.h
// When the longest sequence needs more KV splits than the standard
// heuristic provides, oversubscribe the SMs. This helps mixed-length batches
// where long sequences get insufficient KV parallelism.
int constexpr kMinTokensPerCta = 2048;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The three constant values are too specific and might impact other cases. I would make this as an optional and user-provided options (using ENVs or something).
or I am thinking that we can have a callback function or something so that the heuristic function can be passed.
@bkryu might have better suggestions.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added ENVs and disabled by default.

Comment thread include/flashinfer/trtllm/fmha/fmhaKernels.cuh Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
benchmarks/bench_trtllm_gen_mla.py (2)

123-132: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Enable CUPTI path in benchmark timing configuration.

Line 129 hard-disables CUPTI (enable_cupti=False), which bypasses the intended CUPTI-with-fallback benchmarking path.

Proposed fix
     measurements = bench_gpu_time(
         lambda: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             **common_kwargs
         ),
         dry_run_iters=5,
         repeat_iters=30,
-        enable_cupti=False,
+        enable_cupti=True,
         use_cuda_graph=True,
         cold_l2_cache=True,
     )

As per coding guidelines: "benchmarks/**/*.py: Use flashinfer.testing.bench_gpu_time() for benchmarking with CUPTI support and automatic fallback to CUDA events".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/bench_trtllm_gen_mla.py` around lines 123 - 132, The benchmark
currently disables CUPTI by passing enable_cupti=False to bench_gpu_time, which
prevents exercising the CUPTI-with-fallback path; update the bench call that
wraps flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla (the lambda passed
to bench_gpu_time) to set enable_cupti=True so bench_gpu_time can attempt CUPTI
and automatically fall back to CUDA events if unavailable, leaving other
parameters (dry_run_iters, repeat_iters, use_cuda_graph, cold_l2_cache)
unchanged.

42-47: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Don’t use assert for runtime input validation.

Line 43 can be compiled out with python -O; use explicit checks and reject non-positive sequence lengths to avoid downstream invalid block math.

Proposed fix
     if seq_lens_list is not None:
-        assert len(seq_lens_list) == batch_size, (
-            f"seq_lens_list length {len(seq_lens_list)} != batch_size {batch_size}"
-        )
+        if len(seq_lens_list) != batch_size:
+            raise ValueError(
+                f"seq_lens_list length {len(seq_lens_list)} != batch_size {batch_size}"
+            )
         seq_lens = list(seq_lens_list)
+        if any(s <= 0 for s in seq_lens):
+            raise ValueError("All sequence lengths in seq_lens_list must be > 0")
         max_seq_len = max(seq_lens)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/bench_trtllm_gen_mla.py` around lines 42 - 47, Replace the
assert-based check for seq_lens_list with explicit runtime validation: if
seq_lens_list is not None, verify len(seq_lens_list) == batch_size and raise a
ValueError with a clear message if not; convert seq_lens_list to seq_lens =
list(seq_lens_list) and validate every element is an int > 0, raising ValueError
for any non-positive or non-int entry; then compute max_seq_len = max(seq_lens)
as before. Use the identifiers seq_lens_list, batch_size, seq_lens, and
max_seq_len to locate and update the code.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@benchmarks/bench_trtllm_gen_mla.py`:
- Line 88: The workspace_buffer is allocated with torch.empty which leaves
uninitialized bytes and can corrupt the decode counter region; replace the
allocation with a zero-initialized buffer or explicitly zero it after creation
(e.g., use torch.zeros with same shape, dtype and device or call
workspace_buffer.zero_()) so the counter region is deterministic before first
kernel use; refer to workspace_buffer in the allocation site to apply this
change.

---

Outside diff comments:
In `@benchmarks/bench_trtllm_gen_mla.py`:
- Around line 123-132: The benchmark currently disables CUPTI by passing
enable_cupti=False to bench_gpu_time, which prevents exercising the
CUPTI-with-fallback path; update the bench call that wraps
flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla (the lambda passed to
bench_gpu_time) to set enable_cupti=True so bench_gpu_time can attempt CUPTI and
automatically fall back to CUDA events if unavailable, leaving other parameters
(dry_run_iters, repeat_iters, use_cuda_graph, cold_l2_cache) unchanged.
- Around line 42-47: Replace the assert-based check for seq_lens_list with
explicit runtime validation: if seq_lens_list is not None, verify
len(seq_lens_list) == batch_size and raise a ValueError with a clear message if
not; convert seq_lens_list to seq_lens = list(seq_lens_list) and validate every
element is an int > 0, raising ValueError for any non-positive or non-int entry;
then compute max_seq_len = max(seq_lens) as before. Use the identifiers
seq_lens_list, batch_size, seq_lens, and max_seq_len to locate and update the
code.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3a5bb811-e27e-4a4f-9c4b-fa5e19b8decd

📥 Commits

Reviewing files that changed from the base of the PR and between 1207140 and ec03b0b.

📒 Files selected for processing (4)
  • benchmarks/bench_trtllm_gen_mla.py
  • csrc/fmhaReduction.cu
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/kernelParams.h
💤 Files with no reviewable changes (2)
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh

# Allocate workspace buffer
# todo(Yingyi): calculate the actual size of workspace buffer
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Zero-initialize decode workspace before first kernel use.

Line 88 allocates workspace_buffer with torch.empty(...), but this path requires zeroed workspace for the counter region; uninitialized bytes can make results/timings flaky.

Proposed fix
-    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
+    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8, device=device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/bench_trtllm_gen_mla.py` at line 88, The workspace_buffer is
allocated with torch.empty which leaves uninitialized bytes and can corrupt the
decode counter region; replace the allocation with a zero-initialized buffer or
explicitly zero it after creation (e.g., use torch.zeros with same shape, dtype
and device or call workspace_buffer.zero_()) so the counter region is
deterministic before first kernel use; refer to workspace_buffer in the
allocation site to apply this change.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@include/flashinfer/trtllm/common.h`:
- Around line 200-203: getIntEnv uses std::atoi which silently returns 0 for
invalid input leading to a possible division-by-zero downstream (see
fmhaKernels.cuh use). Replace or augment
getIntEnv/getEnvKvOversubMinTokensPerCta to validate the environment value:
parse with std::stoi inside a try/catch (or check that the env string is all
digits) and enforce a safe minimum (e.g., at least 1) before returning; on parse
failure or out-of-range values fall back to the provided defaultVal and log or
warn. Ensure you reference and update getIntEnv (and the specific getter
getEnvKvOversubMinTokensPerCta if present) so callers never receive zero from
invalid env strings.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7727c285-de88-48db-be17-a3ecb339f748

📥 Commits

Reviewing files that changed from the base of the PR and between ec03b0b and 531fff7.

📒 Files selected for processing (2)
  • include/flashinfer/trtllm/common.h
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Comment on lines +200 to +203
inline static int getIntEnv(char const* name, int defaultVal) {
char const* env = std::getenv(name);
return env ? std::atoi(env) : defaultVal;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

std::atoi silently returns 0 for invalid input, risking division by zero.

If FLASHINFER_KV_OVERSUB_MIN_TOKENS_PER_CTA is set to a non-numeric string (e.g., "abc"), atoi returns 0. This value is used as a divisor in fmhaKernels.cuh line 504, causing a crash.

Consider adding validation in getEnvKvOversubMinTokensPerCta() or using std::stoi with exception handling (consistent with csrc/nv_internal/cpp/common/envUtils.cpp).

🛡️ Proposed fix: add minimum bound in getter
 inline int getEnvKvOversubMinTokensPerCta() {
-  static int const val = getIntEnv("FLASHINFER_KV_OVERSUB_MIN_TOKENS_PER_CTA", 2048);
+  static int const val = std::max(1, getIntEnv("FLASHINFER_KV_OVERSUB_MIN_TOKENS_PER_CTA", 2048));
   return val;
 }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@include/flashinfer/trtllm/common.h` around lines 200 - 203, getIntEnv uses
std::atoi which silently returns 0 for invalid input leading to a possible
division-by-zero downstream (see fmhaKernels.cuh use). Replace or augment
getIntEnv/getEnvKvOversubMinTokensPerCta to validate the environment value:
parse with std::stoi inside a try/catch (or check that the env string is all
digits) and enforce a safe minimum (e.g., at least 1) before returning; on parse
failure or out-of-range values fall back to the provided defaultVal and log or
warn. Ensure you reference and update getIntEnv (and the specific getter
getEnvKvOversubMinTokensPerCta if present) so callers never receive zero from
invalid env strings.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants