[WIP] feat: autotune between trtllm-gen and cute-dsl in MLA decode #3086
nv-yunzheq wants to merge 1 commit into
When `backend="auto"` on SM100/SM103 and both backends are viable for
the given parameters, use the FlashInfer AutoTuner to profile both
trtllm-gen and cute-dsl and select the faster one per batch size.
Benchmarks on GB100 show neither backend dominates:
- cute-dsl is up to 1.51x faster at small batch sizes (1-16)
- trtllm-gen is up to 1.14x faster at large batch sizes (256-1024)
- The crossover depends on batch size and sequence length
The autotuner caches results per (batch_size_bucket) so profiling
only happens once per shape. Users can persist tuning results via:
    with flashinfer.autotuner.autotune(cache="mla_tune.json"):
        model(inputs)
When cute-dsl is not available (cutlass-dsl not installed) or the
call uses features unsupported by cute-dsl (tensor scales, sinks,
sparse MLA, skip-softmax), the autotuner is bypassed and trtllm-gen
is selected directly, preserving the original behavior.
Addresses: #2891
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
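The per-bucket caching described in the commit message can be sketched roughly as follows. This is a minimal illustration, not FlashInfer's AutoTuner: `BucketedTuner`, its methods, and the backend callables are hypothetical names, and the bucketing rule (round down to a power of two, clamp at 2048) is assumed from the tuning config quoted later in the review.

```python
import time
from typing import Callable, Dict


class BucketedTuner:
    """Hypothetical stand-in for an autotuner that caches one winner per bucket."""

    def __init__(self, max_bucket: int = 2048) -> None:
        self.max_bucket = max_bucket
        self.cache: Dict[int, str] = {}  # bucket -> name of the faster backend
        self.profile_runs = 0  # how many times profiling actually happened

    def bucket(self, batch_size: int) -> int:
        """Round batch_size down to a power of two, clamped to max_bucket."""
        rounded = 1 << (max(batch_size, 1).bit_length() - 1)
        return min(rounded, self.max_bucket)

    def choose(
        self,
        batch_size: int,
        backends: Dict[str, Callable[[int], None]],
        repeats: int = 3,
    ) -> str:
        key = self.bucket(batch_size)
        if key not in self.cache:
            # Cache miss: time every backend once for this bucket.
            self.profile_runs += 1
            timings: Dict[str, float] = {}
            for name, fn in backends.items():
                start = time.perf_counter()
                for _ in range(repeats):
                    fn(batch_size)
                timings[name] = time.perf_counter() - start
            self.cache[key] = min(timings, key=lambda name: timings[name])
        return self.cache[key]
```

Under these assumptions, batch sizes 5 and 7 both fall into bucket 4, so the second call reuses the cached choice instead of profiling again.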
📝 Walkthrough
This pull request adds autotuner support to the MLA decode function for SM100/SM103 GPUs, enabling dynamic selection between the trtllm-gen and cute-dsl backends.
Sequence Diagram
    sequenceDiagram
        participant User as User Code
        participant MLA as trtllm_batch_decode_with_kv_cache_mla
        participant Auto as AutoTuner
        participant TrtRunner as TrtllmGenMLARunner
        participant CuteRunner as CuteDslMLARunner
        User->>MLA: Call with backend="auto"<br/>(SM100/SM103)
        MLA->>MLA: Check _is_cute_dsl_viable()
        alt Cute-DSL Viable
            MLA->>Auto: choose_one(TrtRunner, CuteRunner)
            Auto->>TrtRunner: profile forward()
            TrtRunner->>TrtRunner: _run_trtllm_gen_mla()
            Auto->>CuteRunner: profile forward()
            CuteRunner->>CuteRunner: _run_cute_dsl_mla()
            Auto->>MLA: return faster runner
            MLA->>MLA: Execute selected runner
        else Cute-DSL Not Viable
            MLA->>MLA: Use trtllm-gen via<br/>_run_trtllm_gen_mla()
        end
        MLA->>User: return output
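The branch in the diagram can be sketched as a small dispatch function. This is only an illustration under assumptions: `DecodeRequest`, `is_cute_dsl_viable`, and `select_backend` are hypothetical names, and the flags simply mirror the features the commit message lists as unsupported by cute-dsl.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class DecodeRequest:
    # Hypothetical flags for features the commit message says cute-dsl lacks.
    uses_tensor_scales: bool = False
    uses_sinks: bool = False
    sparse_mla: bool = False
    skip_softmax: bool = False


def is_cute_dsl_viable(req: DecodeRequest, cute_dsl_installed: bool) -> bool:
    """cute-dsl is a candidate only if installed and no unsupported feature is used."""
    unsupported = (
        req.uses_tensor_scales
        or req.uses_sinks
        or req.sparse_mla
        or req.skip_softmax
    )
    return cute_dsl_installed and not unsupported


def select_backend(
    req: DecodeRequest,
    cute_dsl_installed: bool,
    pick_faster: Callable[[], str],
) -> str:
    """Bypass tuning and fall back to trtllm-gen when cute-dsl is not viable."""
    if not is_cute_dsl_viable(req, cute_dsl_installed):
        return "trtllm-gen"
    return pick_faster()  # stands in for the autotuner's profiling step
```

Keeping the viability check separate from the profiling callback means the fallback path never touches the tuner, which matches the "preserving the original behavior" claim in the description.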
Estimated Code Review Effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/mla/_core.py (1)
1-46: ⚠️ Potential issue | 🟡 Minor
Address pre-commit formatting failure.
The pipeline indicates that the `ruff-format` hook modified files. Please run `pre-commit run --all-files` locally to apply the required formatting changes before merging.
🧹 Nitpick comments (1)
flashinfer/mla/_core.py (1)
679-681: Consider tuple unpacking for clarity.
Static analysis suggests using tuple unpacking instead of concatenation.
♻️ Suggested change
     if out is None:
    -    out_shape = query.shape[:-1] + (kv_lora_rank,)
    +    out_shape = (*query.shape[:-1], kv_lora_rank)
         out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 621-630: The gen_tuning_buckets field on the DynamicTensorSpec
expects a tuple, but _mla_tuning_buckets(2048) returns a list; update the
assignment in _MLA_TUNING_CONFIG so gen_tuning_buckets is a tuple (e.g., wrap
the call with tuple(...)) or change _mla_tuning_buckets to return a tuple;
adjust the code referencing _MLA_TUNING_CONFIG / DynamicTensorSpec /
TuningConfig accordingly to ensure gen_tuning_buckets has type tuple[int, ...].
- Around line 924-933: The autotuner path ignores a user-provided out tensor
because the call to runner.forward after tuner.choose_one (involving
tuner.choose_one("flashinfer::mla_decode", ...)) does not pass the out argument
and both _TrtllmGenMLARunner.forward and _CuteDslMLARunner.forward currently
treat out as None; fix by changing the call that returns runner.forward(...) to
pass out as a keyword (out=out) and update both _TrtllmGenMLARunner.forward and
_CuteDslMLARunner.forward signatures/implementations to accept an out kwarg and
use it for the output buffer instead of allocating a new tensor when out is
provided.
---
Outside diff comments:
In `@flashinfer/mla/_core.py`:
- Around line 1-46: The file flashinfer/mla/_core.py failed the pre-commit
ruff-format check; run the repository pre-commit hooks (e.g., pre-commit run
--all-files) or run ruff-format locally and re-commit the changes to fix
formatting issues in this module (adjust imports, spacing, and line breaks
around the top-level imports and the logger = logging.getLogger(__name__) line);
ensure the updated formatting is committed so the CI hook no longer modifies
flashinfer.mla._core.
---
Nitpick comments:
In `@flashinfer/mla/_core.py`:
- Around line 679-681: Replace the tuple-concatenation style for building the
output shape with tuple unpacking to improve clarity: when handling the case
where out is None, construct out_shape by unpacking query.shape[:-1] and
appending kv_lora_rank (e.g., using (*query.shape[:-1], kv_lora_rank)), then
call torch.empty with that out_shape and the same dtype and device parameters
(dtype=torch.bfloat16, device=query.device) to allocate out.
    _MLA_TUNING_CONFIG = TuningConfig(
        dynamic_tensor_specs=(
            DynamicTensorSpec(
                input_idx=(0,),  # query tensor
                dim_idx=(0,),  # batch_size dimension
                gen_tuning_buckets=_mla_tuning_buckets(2048),
                map_to_tuning_buckets=lambda x: min(_last_power_of_2(x), 2048),
            ),
        ),
    )
Fix type error: gen_tuning_buckets expects a tuple, not a list.
The pipeline is failing because DynamicTensorSpec.gen_tuning_buckets expects tuple[int, ...] but _mla_tuning_buckets() returns a list.
🔧 Proposed fix
    -            gen_tuning_buckets=_mla_tuning_buckets(2048),
    +            gen_tuning_buckets=tuple(_mla_tuning_buckets(2048)),
Alternatively, modify `_mla_tuning_buckets` to return a tuple (the return annotation must change along with the return value, otherwise mypy will still report a mismatch):
    -def _mla_tuning_buckets(max_batch: int = 2048) -> list:
    +def _mla_tuning_buckets(max_batch: int = 2048) -> tuple[int, ...]:
         """Generate power-of-2 tuning buckets for batch size."""
    -    buckets = []
    +    buckets: list[int] = []
         b = 1
         while b <= max_batch:
             buckets.append(b)
             b *= 2
    -    return buckets
    +    return tuple(buckets)
📝 Committable suggestion
    _MLA_TUNING_CONFIG = TuningConfig(
        dynamic_tensor_specs=(
            DynamicTensorSpec(
                input_idx=(0,),  # query tensor
                dim_idx=(0,),  # batch_size dimension
                gen_tuning_buckets=tuple(_mla_tuning_buckets(2048)),
                map_to_tuning_buckets=lambda x: min(_last_power_of_2(x), 2048),
            ),
        ),
    )
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 626-626: mypy error: Argument "gen_tuning_buckets" to "DynamicTensorSpec" has incompatible type "list[Any]"; expected "tuple[int, ...] | Callable[..., Any]" [arg-type]
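For reference, a self-contained sketch of the two bucket helpers with the tuple return type that mypy asks for. The names below drop the leading underscore to mark them as illustrative; the body of `last_power_of_2` is an assumption, since the PR excerpt only shows how `_last_power_of_2` is called.

```python
def mla_tuning_buckets(max_batch: int = 2048) -> tuple[int, ...]:
    """Generate power-of-2 tuning buckets up to and including max_batch."""
    buckets: list[int] = []
    b = 1
    while b <= max_batch:
        buckets.append(b)
        b *= 2
    return tuple(buckets)


def last_power_of_2(x: int) -> int:
    """Largest power of two that is <= x (assumed behavior of the real helper)."""
    if x < 1:
        raise ValueError("x must be a positive integer")
    return 1 << (x.bit_length() - 1)


def map_to_bucket(batch_size: int, max_batch: int = 2048) -> int:
    """Mirror of the lambda in the tuning config: round down, then clamp."""
    return min(last_power_of_2(batch_size), max_batch)
```

With these definitions every value produced by `map_to_bucket` for a positive batch size is a member of `mla_tuning_buckets(2048)`, which is the property the tuning config relies on.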
    runner, _tactic = tuner.choose_one(
        "flashinfer::mla_decode",
        [trtllm_runner, cute_dsl_runner],
        _MLA_TUNING_CONFIG,
        [query, kv_cache, workspace_buffer, block_tables, seq_lens],
    )
    return runner.forward(
        [query, kv_cache, workspace_buffer, block_tables, seq_lens],
        tactic=_tactic,
    )
User-provided out tensor is ignored when autotuning.
When the autotuner path is taken, the out parameter passed by the user is not forwarded to the runner. Both _TrtllmGenMLARunner.forward and _CuteDslMLARunner.forward pass None for out, causing a new output tensor to be allocated internally even if the caller provided one.
This could cause unexpected behavior for users who pass a pre-allocated output tensor (e.g., for CUDA graph capture scenarios).
🔧 Proposed fix
Pass `out` as a keyword argument to the runners:
     runner, _tactic = tuner.choose_one(
         "flashinfer::mla_decode",
         [trtllm_runner, cute_dsl_runner],
         _MLA_TUNING_CONFIG,
         [query, kv_cache, workspace_buffer, block_tables, seq_lens],
     )
     return runner.forward(
         [query, kv_cache, workspace_buffer, block_tables, seq_lens],
         tactic=_tactic,
    +    out=out,
     )
And update the runners' `forward` methods to accept and use the `out` kwarg:
     def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
         query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
         max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
    +    out = kwargs.get("out", None)
         return _run_trtllm_gen_mla(
             query, kv_cache, workspace_buffer,
             self.qk_nope_head_dim, self.kv_lora_rank, self.qk_rope_head_dim,
             block_tables, seq_lens, max_seq_len, self.sparse_mla_top_k,
    -        None,  # out
    +        out,
             self.bmm1_scale, self.bmm2_scale, self.sinks,
             self.skip_softmax_threshold_scale_factor,
             self.enable_pdl, self.is_var_seq, self.uses_shared_paged_kv_idx,
         )
Apply the same change to `_CuteDslMLARunner.forward`.
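The contract at stake is that a caller-supplied buffer is written in place and returned. A minimal sketch of that pass-through pattern, using plain Python lists in place of tensors; `run_kernel` and `Runner` are hypothetical stand-ins, not the PR's classes.

```python
from typing import List, Optional, Sequence


def run_kernel(
    values: Sequence[float], out: Optional[List[float]] = None
) -> List[float]:
    """Hypothetical backend entry point: writes into `out` when it is provided."""
    if out is None:
        out = [0.0] * len(values)  # allocate only when the caller gave no buffer
    elif len(out) != len(values):
        raise ValueError("out has the wrong length")
    for i, v in enumerate(values):
        out[i] = 2.0 * v  # placeholder computation
    return out


class Runner:
    """Hypothetical runner whose forward() forwards the optional `out` kwarg."""

    def forward(self, inputs, tactic: int = -1, **kwargs):
        (values,) = inputs
        return run_kernel(values, kwargs.get("out"))
```

A caller can then check `result is buf` to confirm that no hidden allocation happened, which is the property a pre-allocated buffer is usually relied on for.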
Code Review
This pull request introduces autotuner integration for MLA decode backend selection, allowing the system to dynamically choose between the trtllm-gen and cute-dsl backends on SM100/SM103 architectures. The review feedback highlights several critical issues regarding the handling of the pre-allocated 'out' tensor, which was being ignored in both the runner implementations and the main autotuning dispatch logic. Additionally, the review identified an incorrect shape validation for the output tensor in multi-token scenarios and suggested an efficiency improvement for module loading.
    else:
        batch_size, _, num_q_heads, _ = query.shape
        check_shape_dtype_device(
            out, [batch_size, num_q_heads, kv_lora_rank],
The shape check for the out tensor is incorrect when q_len > 1 (e.g., in MTP scenarios). The query tensor is 4D [batch_size, q_len, num_heads, head_dim], so the output tensor should also be 4D [batch_size, q_len, num_heads, kv_lora_rank]. The current check uses a 3D shape, which will cause a validation failure. Use the out_shape calculated on line 680 instead.
    -            out, [batch_size, num_q_heads, kv_lora_rank],
    +            out, out_shape,
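To make the shape argument concrete, here is a small sketch that works on plain shape tuples rather than tensors. The helper names and the example dimensions are illustrative assumptions; only the `(*query.shape[:-1], kv_lora_rank)` rule comes from the review.

```python
def expected_out_shape(
    query_shape: tuple[int, ...], kv_lora_rank: int
) -> tuple[int, ...]:
    """Output keeps every query dim except the last, which becomes kv_lora_rank."""
    return (*query_shape[:-1], kv_lora_rank)


def check_out_shape(
    out_shape: tuple[int, ...],
    query_shape: tuple[int, ...],
    kv_lora_rank: int,
) -> None:
    """Raise if a caller-provided output buffer does not match the query layout."""
    expected = expected_out_shape(query_shape, kv_lora_rank)
    if tuple(out_shape) != expected:
        raise ValueError(f"out has shape {tuple(out_shape)}, expected {expected}")


# Illustrative sizes: [batch_size, q_len, num_heads, head_dim] with q_len = 2.
query_shape = (4, 2, 128, 576)
kv_lora_rank = 512
check_out_shape((4, 2, 128, 512), query_shape, kv_lora_rank)  # passes
```

A 3D expectation such as `(batch_size, num_heads, kv_lora_rank)` would reject this valid 4D buffer, which is the failure the comment describes for `q_len > 1`.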
    def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
        query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
        max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
        return _run_trtllm_gen_mla(
            query, kv_cache, workspace_buffer,
            self.qk_nope_head_dim, self.kv_lora_rank, self.qk_rope_head_dim,
            block_tables, seq_lens, max_seq_len, self.sparse_mla_top_k,
            None,  # out
            self.bmm1_scale, self.bmm2_scale, self.sinks,
            self.skip_softmax_threshold_scale_factor,
            self.enable_pdl, self.is_var_seq, self.uses_shared_paged_kv_idx,
        )
The forward method of _TrtllmGenMLARunner ignores the out tensor if it is provided by the user. It hardcodes None when calling _run_trtllm_gen_mla, which causes a new tensor to be allocated and returned, even if the user passed a pre-allocated buffer. This breaks the expected behavior of the out parameter in the public API.
    def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
        query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
        max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
        out = kwargs.get("out")
        return _run_trtllm_gen_mla(
            query, kv_cache, workspace_buffer,
            self.qk_nope_head_dim, self.kv_lora_rank, self.qk_rope_head_dim,
            block_tables, seq_lens, max_seq_len, self.sparse_mla_top_k,
            out,
            self.bmm1_scale, self.bmm2_scale, self.sinks,
            self.skip_softmax_threshold_scale_factor,
            self.enable_pdl, self.is_var_seq, self.uses_shared_paged_kv_idx,
        )
    def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
        query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
        max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
        return _run_cute_dsl_mla(
            query, kv_cache, workspace_buffer,
            self.kv_lora_rank, self.qk_rope_head_dim,
            block_tables, seq_lens, max_seq_len,
            self.bmm1_scale, self.bmm2_scale,
            None,  # out
            self.is_var_seq, self.enable_pdl,
        )
Similar to the TRTLLM runner, the _CuteDslMLARunner also ignores the user-provided out tensor. Please pass it through from kwargs.
    def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
        query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
        max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
        out = kwargs.get("out")
        return _run_cute_dsl_mla(
            query, kv_cache, workspace_buffer,
            self.kv_lora_rank, self.qk_rope_head_dim,
            block_tables, seq_lens, max_seq_len,
            self.bmm1_scale, self.bmm2_scale,
            out,
            self.is_var_seq, self.enable_pdl,
        )
    return runner.forward(
        [query, kv_cache, workspace_buffer, block_tables, seq_lens],
        tactic=_tactic,
    )
When executing the selected runner after autotuning, the out tensor provided by the user is not passed to the forward call. This will result in the user-provided buffer being ignored and a new tensor being returned instead.
    return runner.forward(
        [query, kv_cache, workspace_buffer, block_tables, seq_lens],
        tactic=_tactic,
        out=out,
    )
        enable_pdl, is_var_seq, uses_shared_paged_kv_idx,
    ):
        """Execute trtllm-gen backend (shared by runner and direct dispatch)."""
        run_func = gen_trtllm_gen_fmha_module().trtllm_paged_attention_decode
Using gen_trtllm_gen_fmha_module() directly here is inefficient as it recreates the module object on every call. Please use the cached version get_trtllm_gen_fmha_module() instead, which is already defined in this file.
    -    run_func = gen_trtllm_gen_fmha_module().trtllm_paged_attention_decode
    +    run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode