KV Split Oversubscription for Mixed Sequence Lengths#3379
Conversation
Signed-off-by: root <root@slurm-eus-04a-prod-b200-192-107.slurm-eus-04a-prod-compute.tenant-slurm.svc.cluster.local>
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughBenchmarks accept per-request mixed sequence lengths and run mixed-length workloads; FMHA kernel CTA/cluster heuristics and multi-CTA workspace sizing were adapted to use CtaLaunchParams-derived dimensions; reduction path now ensures at least one CTA is launched. ChangesMixed-Sequence MLA Workload Support
Sequence DiagramsequenceDiagram
participant Main as __main__
participant Sampler as sample_prod_distribution
participant Bench as bench_trtllm_mla
participant Kernel as TllmGenFmhaKernel
Main->>Sampler: batch_size, seed
Sampler-->>Main: seq_lens_list
Main->>Bench: bench_trtllm_mla(seq_lens_list=sl, seq_len=max(sl), batch_size, q_len_per_request=1, page_size=32)
Bench->>Kernel: launch FMHA kernels with ctaLaunchParams derived from seq_lens
Kernel->>Kernel: KernelParams::setKernelParams(params, kernelMeta, ctaLaunchParams)
Kernel->>Kernel: computeCtaAndClusterConfig -> baseCtas, numCtasPerSeqKv (maybe oversubscribe)
Kernel-->>Bench: kernel launch / execution
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces mixed sequence length benchmarks using production-like distributions and optimizes KV parallelism for long sequences in mixed-length batches by oversubscribing SMs. Key changes include refining the partial buffer size calculation and ensuring a minimum number of CTAs for reduction. Feedback highlights that the workspace buffer size may still be insufficient for large batches, recommends optimizing memory allocation by using actual sequence lengths instead of the maximum for block calculations, and identifies a potential under-calculation in the partial buffer size logic that could lead to out-of-bounds memory access.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmarks/bench_trtllm_gen_mla.py`:
- Around line 34-40: Check and validate seq_lens_list before using it: in the
branch that reads seq_lens_list (affecting variables seq_lens and max_seq_len),
ensure seq_lens_list is not None, not empty, has length == batch_size, and
contains integer values in the range [1, seq_len]; if validation fails either
raise a clear ValueError (including batch_size and seq_len) or coerce/fill
missing entries (e.g., pad/truncate to batch_size and clamp values to [1,
seq_len]) before calling max(seq_lens) and downstream indexing.
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 832-836: partialStatsBufferSize is computed from
numCtasForAllHeads_ using kernelMeta.mStepQ but the kernel launch groups heads
differently (mGroupsHeadsQ / mNumHeadsQPerKv), so the buffer can be undersized
causing OOB for ptrPartialStats/ptrPartialO; update the calculation so
numCtasForAllHeads_ is derived from the same grouping used at launch (use
options.mGroupsHeadsQ or kernelMeta.mNumHeadsQPerKv / options.mNumHeadsQPerKv as
appropriate) instead of options.mNumHeadsQ / kernelMeta.mStepQ, then recompute
partialStatsBufferSize with that corrected numCtasForAllHeads_ to ensure full
allocation for all grouped head CTAs.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c33f1a3a-5d52-467a-948e-e11e63ec47c0
📒 Files selected for processing (4)
benchmarks/bench_trtllm_gen_mla.pycsrc/fmhaReduction.cuinclude/flashinfer/trtllm/fmha/fmhaKernels.cuhinclude/flashinfer/trtllm/fmha/kernelParams.h
There was a problem hiding this comment.
♻️ Duplicate comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
832-838:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winAdd the divisibility check that was part of the prior suggested fix.
The
numHeadsPerCta_/numCtasForAllHeads_derivation matches the previous proposal, but the divisibility guard was omitted. WhenkernelMeta.mGroupsHeadsQis true andoptions.mNumHeadsQPerKv >= kernelMeta.mStepQ, the divisor on Line 834 becomesmStepQ, andoptions.mNumHeadsQ % mStepQis not guaranteed to be zero. Silent truncation here under-allocatespartialStatsBufferSizeand can produce OOB writes/reads viaptrPartialStats/ptrPartialO.🛡️ Proposed fix to add the divisibility guard
int const numHeadsPerCta_ = kernelMeta.mGroupsHeadsQ ? std::min(options.mNumHeadsQPerKv, kernelMeta.mStepQ) : 1; + FLASHINFER_CHECK(options.mNumHeadsQ % numHeadsPerCta_ == 0, + "numHeadsQ (%d) must be divisible by numHeadsPerCta (%d)", + options.mNumHeadsQ, numHeadsPerCta_); int const numCtasForAllHeads_ = options.mNumHeadsQ / numHeadsPerCta_;🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 832 - 838, Ensure the number of Q heads is divisible by heads-per-CTA before using it as a divisor: when kernelMeta.mGroupsHeadsQ is true compute numHeadsPerCta_ as shown, then assert or handle the case where options.mNumHeadsQ % numHeadsPerCta_ != 0 (e.g., return error, throw, or round up and document) before computing numCtasForAllHeads_ and partialStatsBufferSize so you don't silently truncate and under-allocate; update the code paths that use numCtasForAllHeads_, partialStatsBufferSize, and pointers ptrPartialStats/ptrPartialO to rely only on the validated/divisible value.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Duplicate comments:
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 832-838: Ensure the number of Q heads is divisible by
heads-per-CTA before using it as a divisor: when kernelMeta.mGroupsHeadsQ is
true compute numHeadsPerCta_ as shown, then assert or handle the case where
options.mNumHeadsQ % numHeadsPerCta_ != 0 (e.g., return error, throw, or round
up and document) before computing numCtasForAllHeads_ and partialStatsBufferSize
so you don't silently truncate and under-allocate; update the code paths that
use numCtasForAllHeads_, partialStatsBufferSize, and pointers
ptrPartialStats/ptrPartialO to rely only on the validated/divisible value.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f22251bc-3145-4157-a0a8-c2cc517ec89c
📒 Files selected for processing (2)
benchmarks/bench_trtllm_gen_mla.pyinclude/flashinfer/trtllm/fmha/kernelParams.h
| // When the longest sequence needs more KV splits than the standard | ||
| // heuristic provides, oversubscribe the SMs. This helps mixed-length batches | ||
| // where long sequences get insufficient KV parallelism. | ||
| int constexpr kMinTokensPerCta = 2048; |
There was a problem hiding this comment.
The three constant values are too specific and might impact other cases. I would make this as an optional and user-provided options (using ENVs or something).
or I am thinking that we can have a callback function or something so that the heuristic function can be passed.
@bkryu might have better suggestions.
There was a problem hiding this comment.
I added ENVs and disabled by default.
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
benchmarks/bench_trtllm_gen_mla.py (2)
123-132: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick winEnable CUPTI path in benchmark timing configuration.
Line 129 hard-disables CUPTI (
enable_cupti=False), which bypasses the intended CUPTI-with-fallback benchmarking path.Proposed fix
measurements = bench_gpu_time( lambda: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( **common_kwargs ), dry_run_iters=5, repeat_iters=30, - enable_cupti=False, + enable_cupti=True, use_cuda_graph=True, cold_l2_cache=True, )As per coding guidelines: "
benchmarks/**/*.py: Useflashinfer.testing.bench_gpu_time()for benchmarking with CUPTI support and automatic fallback to CUDA events".🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@benchmarks/bench_trtllm_gen_mla.py` around lines 123 - 132, The benchmark currently disables CUPTI by passing enable_cupti=False to bench_gpu_time, which prevents exercising the CUPTI-with-fallback path; update the bench call that wraps flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla (the lambda passed to bench_gpu_time) to set enable_cupti=True so bench_gpu_time can attempt CUPTI and automatically fall back to CUDA events if unavailable, leaving other parameters (dry_run_iters, repeat_iters, use_cuda_graph, cold_l2_cache) unchanged.
42-47:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winDon’t use
assertfor runtime input validation.Line 43 can be compiled out with
python -O; use explicit checks and reject non-positive sequence lengths to avoid downstream invalid block math.Proposed fix
if seq_lens_list is not None: - assert len(seq_lens_list) == batch_size, ( - f"seq_lens_list length {len(seq_lens_list)} != batch_size {batch_size}" - ) + if len(seq_lens_list) != batch_size: + raise ValueError( + f"seq_lens_list length {len(seq_lens_list)} != batch_size {batch_size}" + ) seq_lens = list(seq_lens_list) + if any(s <= 0 for s in seq_lens): + raise ValueError("All sequence lengths in seq_lens_list must be > 0") max_seq_len = max(seq_lens)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@benchmarks/bench_trtllm_gen_mla.py` around lines 42 - 47, Replace the assert-based check for seq_lens_list with explicit runtime validation: if seq_lens_list is not None, verify len(seq_lens_list) == batch_size and raise a ValueError with a clear message if not; convert seq_lens_list to seq_lens = list(seq_lens_list) and validate every element is an int > 0, raising ValueError for any non-positive or non-int entry; then compute max_seq_len = max(seq_lens) as before. Use the identifiers seq_lens_list, batch_size, seq_lens, and max_seq_len to locate and update the code.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmarks/bench_trtllm_gen_mla.py`:
- Line 88: The workspace_buffer is allocated with torch.empty which leaves
uninitialized bytes and can corrupt the decode counter region; replace the
allocation with a zero-initialized buffer or explicitly zero it after creation
(e.g., use torch.zeros with same shape, dtype and device or call
workspace_buffer.zero_()) so the counter region is deterministic before first
kernel use; refer to workspace_buffer in the allocation site to apply this
change.
---
Outside diff comments:
In `@benchmarks/bench_trtllm_gen_mla.py`:
- Around line 123-132: The benchmark currently disables CUPTI by passing
enable_cupti=False to bench_gpu_time, which prevents exercising the
CUPTI-with-fallback path; update the bench call that wraps
flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla (the lambda passed to
bench_gpu_time) to set enable_cupti=True so bench_gpu_time can attempt CUPTI and
automatically fall back to CUDA events if unavailable, leaving other parameters
(dry_run_iters, repeat_iters, use_cuda_graph, cold_l2_cache) unchanged.
- Around line 42-47: Replace the assert-based check for seq_lens_list with
explicit runtime validation: if seq_lens_list is not None, verify
len(seq_lens_list) == batch_size and raise a ValueError with a clear message if
not; convert seq_lens_list to seq_lens = list(seq_lens_list) and validate every
element is an int > 0, raising ValueError for any non-positive or non-int entry;
then compute max_seq_len = max(seq_lens) as before. Use the identifiers
seq_lens_list, batch_size, seq_lens, and max_seq_len to locate and update the
code.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3a5bb811-e27e-4a4f-9c4b-fa5e19b8decd
📒 Files selected for processing (4)
benchmarks/bench_trtllm_gen_mla.pycsrc/fmhaReduction.cuinclude/flashinfer/trtllm/fmha/fmhaKernels.cuhinclude/flashinfer/trtllm/fmha/kernelParams.h
💤 Files with no reviewable changes (2)
- include/flashinfer/trtllm/fmha/kernelParams.h
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
| # Allocate workspace buffer | ||
| # todo(Yingyi): calculate the actual size of workspace buffer | ||
| workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) | ||
| workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) |
There was a problem hiding this comment.
Zero-initialize decode workspace before first kernel use.
Line 88 allocates workspace_buffer with torch.empty(...), but this path requires zeroed workspace for the counter region; uninitialized bytes can make results/timings flaky.
Proposed fix
- workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
+ workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8, device=device)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@benchmarks/bench_trtllm_gen_mla.py` at line 88, The workspace_buffer is
allocated with torch.empty which leaves uninitialized bytes and can corrupt the
decode counter region; replace the allocation with a zero-initialized buffer or
explicitly zero it after creation (e.g., use torch.zeros with same shape, dtype
and device or call workspace_buffer.zero_()) so the counter region is
deterministic before first kernel use; refer to workspace_buffer in the
allocation site to apply this change.
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@include/flashinfer/trtllm/common.h`:
- Around line 200-203: getIntEnv uses std::atoi which silently returns 0 for
invalid input leading to a possible division-by-zero downstream (see
fmhaKernels.cuh use). Replace or augment
getIntEnv/getEnvKvOversubMinTokensPerCta to validate the environment value:
parse with std::stoi inside a try/catch (or check that the env string is all
digits) and enforce a safe minimum (e.g., at least 1) before returning; on parse
failure or out-of-range values fall back to the provided defaultVal and log or
warn. Ensure you reference and update getIntEnv (and the specific getter
getEnvKvOversubMinTokensPerCta if present) so callers never receive zero from
invalid env strings.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 7727c285-de88-48db-be17-a3ecb339f748
📒 Files selected for processing (2)
include/flashinfer/trtllm/common.hinclude/flashinfer/trtllm/fmha/fmhaKernels.cuh
| inline static int getIntEnv(char const* name, int defaultVal) { | ||
| char const* env = std::getenv(name); | ||
| return env ? std::atoi(env) : defaultVal; | ||
| } |
There was a problem hiding this comment.
std::atoi silently returns 0 for invalid input, risking division by zero.
If FLASHINFER_KV_OVERSUB_MIN_TOKENS_PER_CTA is set to a non-numeric string (e.g., "abc"), atoi returns 0. This value is used as a divisor in fmhaKernels.cuh line 504, causing a crash.
Consider adding validation in getEnvKvOversubMinTokensPerCta() or using std::stoi with exception handling (consistent with csrc/nv_internal/cpp/common/envUtils.cpp).
🛡️ Proposed fix: add minimum bound in getter
inline int getEnvKvOversubMinTokensPerCta() {
- static int const val = getIntEnv("FLASHINFER_KV_OVERSUB_MIN_TOKENS_PER_CTA", 2048);
+ static int const val = std::max(1, getIntEnv("FLASHINFER_KV_OVERSUB_MIN_TOKENS_PER_CTA", 2048));
return val;
}🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@include/flashinfer/trtllm/common.h` around lines 200 - 203, getIntEnv uses
std::atoi which silently returns 0 for invalid input leading to a possible
division-by-zero downstream (see fmhaKernels.cuh use). Replace or augment
getIntEnv/getEnvKvOversubMinTokensPerCta to validate the environment value:
parse with std::stoi inside a try/catch (or check that the env string is all
digits) and enforce a safe minimum (e.g., at least 1) before returning; on parse
failure or out-of-range values fall back to the provided defaultVal and log or
warn. Ensure you reference and update getIntEnv (and the specific getter
getEnvKvOversubMinTokensPerCta if present) so callers never receive zero from
invalid env strings.
📌 Description
When a batch contains mixed sequence lengths, the existing heuristic in computeCtaAndClusterConfig assigns numCtasPerSeqKv based on filling one wave of SM occupancy. This under-splits long sequences.
This PR allows SM oversubscription — splitting KV beyond one wave when max_seq_len demands it. Short sequences' extra CTAs early-exit. Three tunable constants control the behavior: kMinTokensPerCta, kMaxOccupancyWaves, and kMaxSplits.
The biggest gains are visible when the distribution is close to prod-like, for that I added distribution
[4096, 8192, 16384, 32768, 65536, 131072]with weights[35, 20, 15, 12, 10, 8]and benchmarks based on that.Benchmarks for DSv3 Running TP8
constants tunning:
Results from
benchmarks/bench_trtllm_gen_mla.pyrandom context len:
Context length distribution:
[4096, 8192, 16384, 32768, 65536, 131072]with weights[35, 20, 15, 12, 10, 8].✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Summary by CodeRabbit
Bug Fixes
Performance Improvements
New Features