Skip to content

[WIP] feat: autotune between trtllm-gen and cute-dsl in MLA decode#3086

Closed
nv-yunzheq wants to merge 1 commit into
mainfrom
feat/mla-decode-autotune
Closed

[WIP] feat: autotune between trtllm-gen and cute-dsl in MLA decode#3086
nv-yunzheq wants to merge 1 commit into
mainfrom
feat/mla-decode-autotune

Conversation

@nv-yunzheq
Copy link
Copy Markdown
Collaborator

@nv-yunzheq nv-yunzheq commented Apr 16, 2026

When backend="auto" on SM100/SM103 and both backends are viable for the given parameters, use the FlashInfer AutoTuner to profile both trtllm-gen and cute-dsl and select the faster one per batch size.

Benchmarks on GB100 show neither backend dominates:

  • cute-dsl is up to 1.51x faster at small batch sizes (1-16)
  • trtllm-gen is up to 1.14x faster at large batch sizes (256-1024)
  • The crossover depends on batch size and sequence length

The autotuner caches results per (batch_size_bucket) so profiling only happens once per shape. Users can persist tuning results via:
with flashinfer.autotuner.autotune(cache="mla_tune.json"):
model(inputs)

When cute-dsl is not available (cutlass-dsl not installed) or the call uses features unsupported by cute-dsl (tensor scales, sinks, sparse MLA, skip-softmax), the autotuner is bypassed and trtllm-gen is selected directly, preserving the original behavior.

Addresses: #2891

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added automatic backend selection optimization for MLA decode operations on SM100/SM103 GPUs. The system now intelligently chooses the most suitable backend based on your configuration.
  • Improvements

    • Refactored backend dispatch logic with enhanced internal optimizations for better performance and maintainability.

When `backend="auto"` on SM100/SM103 and both backends are viable for
the given parameters, use the FlashInfer AutoTuner to profile both
trtllm-gen and cute-dsl and select the faster one per batch size.

Benchmarks on GB100 show neither backend dominates:
- cute-dsl is up to 1.51x faster at small batch sizes (1-16)
- trtllm-gen is up to 1.14x faster at large batch sizes (256-1024)
- The crossover depends on batch size and sequence length

The autotuner caches results per (batch_size_bucket) so profiling
only happens once per shape. Users can persist tuning results via:
  with flashinfer.autotuner.autotune(cache="mla_tune.json"):
      model(inputs)

When cute-dsl is not available (cutlass-dsl not installed) or the
call uses features unsupported by cute-dsl (tensor scales, sinks,
sparse MLA, skip-softmax), the autotuner is bypassed and trtllm-gen
is selected directly, preserving the original behavior.

Addresses: #2891

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 16, 2026

📝 Walkthrough

Walkthrough

This pull request adds autotuner support to the MLA decode function for SM100/SM103 GPUs, enabling dynamic selection between trtllm-gen and cute-dsl backends. It introduces backend execution wrappers, TunableRunner implementations, and refactors direct backend branches to centralize logic.

Changes

Cohort / File(s) Summary
Autotuner Infrastructure
flashinfer/mla/_core.py
Added _last_power_of_2 helper, _mla_tuning_buckets/_MLA_TUNING_CONFIG for batch-size bucket-based tuning configuration, and _is_cute_dsl_viable to gate backend viability (rejects bmm*_scale, sinks, sparse MLA, skip-softmax; requires uses_shared_paged_kv_idx).
Backend Execution Wrappers
flashinfer/mla/_core.py
Introduced _run_trtllm_gen_mla and _run_cute_dsl_mla to centralize shape validation, output allocation, block-size checks, and underlying backend calls.
TunableRunner Implementations
flashinfer/mla/_core.py
Added _TrtllmGenMLARunner and _CuteDslMLARunner classes implementing the TunableRunner interface to enable integration with AutoTuner.get().choose_one(...).
Backend Integration & Refactoring
flashinfer/mla/_core.py
Refactored backend branches (backend=="trtllm-gen" and backend=="cute-dsl") to delegate to new _run_* helpers; integrated autotuner dispatch for SM100/SM103 when backend="auto"; added module-level logger; updated docstring for trtllm_batch_decode_with_kv_cache_mla.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant MLA as trtllm_batch_decode_with_kv_cache_mla
    participant Auto as AutoTuner
    participant TrtRunner as TrtllmGenMLARunner
    participant CuteRunner as CuteDslMLARunner
    
    User->>MLA: Call with backend="auto"<br/>(SM100/SM103)
    MLA->>MLA: Check _is_cute_dsl_viable()
    alt Cute-DSL Viable
        MLA->>Auto: choose_one(TrtRunner, CuteRunner)
        Auto->>TrtRunner: profile forward()
        TrtRunner->>TrtRunner: _run_trtllm_gen_mla()
        Auto->>CuteRunner: profile forward()
        CuteRunner->>CuteRunner: _run_cute_dsl_mla()
        Auto->>MLA: return faster runner
        MLA->>MLA: Execute selected runner
    else Cute-DSL Not Viable
        MLA->>MLA: Use trtllm-gen via<br/>_run_trtllm_gen_mla()
    end
    MLA->>User: return output
Loading

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly Related Issues

Possibly Related PRs

Suggested Labels

run-ci

Suggested Reviewers

  • cyx-6
  • bkryu
  • yzh119
  • yongwww
  • sricketts
  • aleozlx

Poem

🐰 Hops with glee through tuning buckets so fine,
Autotuner picks the backend—trtllm-gen or cute divine!
SM100, SM103, now choose with care,
Faster MLA decode floats through the air!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check ❓ Inconclusive The PR description provides clear context about the autotuning feature, benchmarks, behavior on incompatible features, and links the related issue. However, the required checklist items (pre-commit, tests) are marked unchecked and not addressed. Confirm whether pre-commit checks have been run and tests have been added/updated. If checks are incomplete, complete them or clarify the [WIP] status intent.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: introducing autotuning between two MLA decode backends on SM100/SM103.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/mla-decode-autotune

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/mla/_core.py (1)

1-46: ⚠️ Potential issue | 🟡 Minor

Address pre-commit formatting failure.

The pipeline indicates that the ruff-format hook modified files. Please run pre-commit run --all-files locally to apply the required formatting changes before merging.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/_core.py` around lines 1 - 46, The file
flashinfer/mla/_core.py failed the pre-commit ruff-format check; run the
repository pre-commit hooks (e.g., pre-commit run --all-files) or run
ruff-format locally and re-commit the changes to fix formatting issues in this
module (adjust imports, spacing, and line breaks around the top-level imports
and the logger = logging.getLogger(__name__) line); ensure the updated
formatting is committed so the CI hook no longer modifies flashinfer.mla._core.
🧹 Nitpick comments (1)
flashinfer/mla/_core.py (1)

679-681: Consider tuple unpacking for clarity.

Static analysis suggests using tuple unpacking instead of concatenation.

♻️ Suggested change
     if out is None:
-        out_shape = query.shape[:-1] + (kv_lora_rank,)
+        out_shape = (*query.shape[:-1], kv_lora_rank)
         out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/_core.py` around lines 679 - 681, Replace the
list-concatenation style for building the output shape with tuple unpacking to
improve clarity: when handling the case where out is None, construct out_shape
by unpacking query.shape[:-1] and appending kv_lora_rank (e.g., using
(*query.shape[:-1], kv_lora_rank)), then call torch.empty with that out_shape
and the same dtype and device parameters (dtype=torch.bfloat16,
device=query.device) to allocate out.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 621-630: The gen_tuning_buckets field on the DynamicTensorSpec
expects a tuple, but _mla_tuning_buckets(2048) returns a list; update the
assignment in _MLA_TUNING_CONFIG so gen_tuning_buckets is a tuple (e.g., wrap
the call with tuple(...)) or change _mla_tuning_buckets to return a tuple;
adjust the code referencing _MLA_TUNING_CONFIG / DynamicTensorSpec /
TuningConfig accordingly to ensure gen_tuning_buckets has type tuple[int, ...].
- Around line 924-933: The autotuner path ignores a user-provided out tensor
because the call to runner.forward after tuner.choose_one (involving
tuner.choose_one("flashinfer::mla_decode", ...)) does not pass the out argument
and both _TrtllmGenMLARunner.forward and _CuteDslMLARunner.forward currently
treat out as None; fix by changing the call that returns runner.forward(...) to
pass out as a keyword (out=out) and update both _TrtllmGenMLARunner.forward and
_CuteDslMLARunner.forward signatures/implementations to accept an out kwarg and
use it for the output buffer instead of allocating a new tensor when out is
provided.

---

Outside diff comments:
In `@flashinfer/mla/_core.py`:
- Around line 1-46: The file flashinfer/mla/_core.py failed the pre-commit
ruff-format check; run the repository pre-commit hooks (e.g., pre-commit run
--all-files) or run ruff-format locally and re-commit the changes to fix
formatting issues in this module (adjust imports, spacing, and line breaks
around the top-level imports and the logger = logging.getLogger(__name__) line);
ensure the updated formatting is committed so the CI hook no longer modifies
flashinfer.mla._core.

---

Nitpick comments:
In `@flashinfer/mla/_core.py`:
- Around line 679-681: Replace the list-concatenation style for building the
output shape with tuple unpacking to improve clarity: when handling the case
where out is None, construct out_shape by unpacking query.shape[:-1] and
appending kv_lora_rank (e.g., using (*query.shape[:-1], kv_lora_rank)), then
call torch.empty with that out_shape and the same dtype and device parameters
(dtype=torch.bfloat16, device=query.device) to allocate out.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 470a2e9b-8b92-42d9-bb2f-956700cf3634

📥 Commits

Reviewing files that changed from the base of the PR and between a99ee72 and c1b0b16.

📒 Files selected for processing (1)
  • flashinfer/mla/_core.py

Comment thread flashinfer/mla/_core.py
Comment on lines +621 to +630
_MLA_TUNING_CONFIG = TuningConfig(
dynamic_tensor_specs=(
DynamicTensorSpec(
input_idx=(0,), # query tensor
dim_idx=(0,), # batch_size dimension
gen_tuning_buckets=_mla_tuning_buckets(2048),
map_to_tuning_buckets=lambda x: min(_last_power_of_2(x), 2048),
),
),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix type error: gen_tuning_buckets expects a tuple, not a list.

The pipeline is failing because DynamicTensorSpec.gen_tuning_buckets expects tuple[int, ...] but _mla_tuning_buckets() returns a list.

🔧 Proposed fix
-            gen_tuning_buckets=_mla_tuning_buckets(2048),
+            gen_tuning_buckets=tuple(_mla_tuning_buckets(2048)),

Alternatively, modify _mla_tuning_buckets to return a tuple:

 def _mla_tuning_buckets(max_batch: int = 2048) -> list:
     """Generate power-of-2 tuning buckets for batch size."""
-    buckets = []
+    buckets: list[int] = []
     b = 1
     while b <= max_batch:
         buckets.append(b)
         b *= 2
-    return buckets
+    return tuple(buckets)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
_MLA_TUNING_CONFIG = TuningConfig(
dynamic_tensor_specs=(
DynamicTensorSpec(
input_idx=(0,), # query tensor
dim_idx=(0,), # batch_size dimension
gen_tuning_buckets=_mla_tuning_buckets(2048),
map_to_tuning_buckets=lambda x: min(_last_power_of_2(x), 2048),
),
),
)
_MLA_TUNING_CONFIG = TuningConfig(
dynamic_tensor_specs=(
DynamicTensorSpec(
input_idx=(0,), # query tensor
dim_idx=(0,), # batch_size dimension
gen_tuning_buckets=tuple(_mla_tuning_buckets(2048)),
map_to_tuning_buckets=lambda x: min(_last_power_of_2(x), 2048),
),
),
)
🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 626-626: mypy error: Argument "gen_tuning_buckets" to "DynamicTensorSpec" has incompatible type "list[Any]"; expected "tuple[int, ...] | Callable[..., Any]" [arg-type]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/_core.py` around lines 621 - 630, The gen_tuning_buckets field
on the DynamicTensorSpec expects a tuple, but _mla_tuning_buckets(2048) returns
a list; update the assignment in _MLA_TUNING_CONFIG so gen_tuning_buckets is a
tuple (e.g., wrap the call with tuple(...)) or change _mla_tuning_buckets to
return a tuple; adjust the code referencing _MLA_TUNING_CONFIG /
DynamicTensorSpec / TuningConfig accordingly to ensure gen_tuning_buckets has
type tuple[int, ...].

Comment thread flashinfer/mla/_core.py
Comment on lines +924 to +933
runner, _tactic = tuner.choose_one(
"flashinfer::mla_decode",
[trtllm_runner, cute_dsl_runner],
_MLA_TUNING_CONFIG,
[query, kv_cache, workspace_buffer, block_tables, seq_lens],
)
return runner.forward(
[query, kv_cache, workspace_buffer, block_tables, seq_lens],
tactic=_tactic,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

User-provided out tensor is ignored when autotuning.

When the autotuner path is taken, the out parameter passed by the user is not forwarded to the runner. Both _TrtllmGenMLARunner.forward and _CuteDslMLARunner.forward pass None for out, causing a new output tensor to be allocated internally even if the caller provided one.

This could cause unexpected behavior for users who pass a pre-allocated output tensor (e.g., for CUDA graph capture scenarios).

🔧 Proposed fix

Pass out as a keyword argument to the runners:

                 runner, _tactic = tuner.choose_one(
                     "flashinfer::mla_decode",
                     [trtllm_runner, cute_dsl_runner],
                     _MLA_TUNING_CONFIG,
                     [query, kv_cache, workspace_buffer, block_tables, seq_lens],
                 )
                 return runner.forward(
                     [query, kv_cache, workspace_buffer, block_tables, seq_lens],
                     tactic=_tactic,
+                    out=out,
                 )

And update the runner's forward methods to accept and use the out kwarg:

     def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
         query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
         max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
+        out = kwargs.get("out", None)
         return _run_trtllm_gen_mla(
             query, kv_cache, workspace_buffer,
             self.qk_nope_head_dim, self.kv_lora_rank, self.qk_rope_head_dim,
             block_tables, seq_lens, max_seq_len, self.sparse_mla_top_k,
-            None,  # out
+            out,
             self.bmm1_scale, self.bmm2_scale, self.sinks,
             self.skip_softmax_threshold_scale_factor,
             self.enable_pdl, self.is_var_seq, self.uses_shared_paged_kv_idx,
         )

Apply the same change to _CuteDslMLARunner.forward.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/_core.py` around lines 924 - 933, The autotuner path ignores a
user-provided out tensor because the call to runner.forward after
tuner.choose_one (involving tuner.choose_one("flashinfer::mla_decode", ...))
does not pass the out argument and both _TrtllmGenMLARunner.forward and
_CuteDslMLARunner.forward currently treat out as None; fix by changing the call
that returns runner.forward(...) to pass out as a keyword (out=out) and update
both _TrtllmGenMLARunner.forward and _CuteDslMLARunner.forward
signatures/implementations to accept an out kwarg and use it for the output
buffer instead of allocating a new tensor when out is provided.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces autotuner integration for MLA decode backend selection, allowing the system to dynamically choose between the trtllm-gen and cute-dsl backends on SM100/SM103 architectures. The review feedback highlights several critical issues regarding the handling of the pre-allocated 'out' tensor, which was being ignored in both the runner implementations and the main autotuning dispatch logic. Additionally, the review identified an incorrect shape validation for the output tensor in multi-token scenarios and suggested an efficiency improvement for module loading.

Comment thread flashinfer/mla/_core.py
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out, [batch_size, num_q_heads, kv_lora_rank],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The shape check for the out tensor is incorrect when q_len > 1 (e.g., in MTP scenarios). The query tensor is 4D [batch_size, q_len, num_heads, head_dim], so the output tensor should also be 4D [batch_size, q_len, num_heads, kv_lora_rank]. The current check uses a 3D shape, which will cause a validation failure. Use the out_shape calculated on line 680 instead.

Suggested change
out, [batch_size, num_q_heads, kv_lora_rank],
out, out_shape,

Comment thread flashinfer/mla/_core.py
Comment on lines +746 to +757
def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
return _run_trtllm_gen_mla(
query, kv_cache, workspace_buffer,
self.qk_nope_head_dim, self.kv_lora_rank, self.qk_rope_head_dim,
block_tables, seq_lens, max_seq_len, self.sparse_mla_top_k,
None, # out
self.bmm1_scale, self.bmm2_scale, self.sinks,
self.skip_softmax_threshold_scale_factor,
self.enable_pdl, self.is_var_seq, self.uses_shared_paged_kv_idx,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The forward method of _TrtllmGenMLARunner ignores the out tensor if it is provided by the user. It hardcodes None when calling _run_trtllm_gen_mla, which causes a new tensor to be allocated and returned, even if the user passed a pre-allocated buffer. This breaks the expected behavior of the out parameter in the public API.

Suggested change
def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
return _run_trtllm_gen_mla(
query, kv_cache, workspace_buffer,
self.qk_nope_head_dim, self.kv_lora_rank, self.qk_rope_head_dim,
block_tables, seq_lens, max_seq_len, self.sparse_mla_top_k,
None, # out
self.bmm1_scale, self.bmm2_scale, self.sinks,
self.skip_softmax_threshold_scale_factor,
self.enable_pdl, self.is_var_seq, self.uses_shared_paged_kv_idx,
)
def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
out = kwargs.get("out")
return _run_trtllm_gen_mla(
query, kv_cache, workspace_buffer,
self.qk_nope_head_dim, self.kv_lora_rank, self.qk_rope_head_dim,
block_tables, seq_lens, max_seq_len, self.sparse_mla_top_k,
out,
self.bmm1_scale, self.bmm2_scale, self.sinks,
self.skip_softmax_threshold_scale_factor,
self.enable_pdl, self.is_var_seq, self.uses_shared_paged_kv_idx,
)

Comment thread flashinfer/mla/_core.py
Comment on lines +775 to +785
def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
return _run_cute_dsl_mla(
query, kv_cache, workspace_buffer,
self.kv_lora_rank, self.qk_rope_head_dim,
block_tables, seq_lens, max_seq_len,
self.bmm1_scale, self.bmm2_scale,
None, # out
self.is_var_seq, self.enable_pdl,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the TRTLLM runner, the _CuteDslMLARunner also ignores the user-provided out tensor. Please pass it through from kwargs.

Suggested change
def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
return _run_cute_dsl_mla(
query, kv_cache, workspace_buffer,
self.kv_lora_rank, self.qk_rope_head_dim,
block_tables, seq_lens, max_seq_len,
self.bmm1_scale, self.bmm2_scale,
None, # out
self.is_var_seq, self.enable_pdl,
)
def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs):
query, kv_cache, workspace_buffer, block_tables, seq_lens = inputs
max_seq_len = int(seq_lens.max().item()) if seq_lens.numel() > 0 else 0
out = kwargs.get("out")
return _run_cute_dsl_mla(
query, kv_cache, workspace_buffer,
self.kv_lora_rank, self.qk_rope_head_dim,
block_tables, seq_lens, max_seq_len,
self.bmm1_scale, self.bmm2_scale,
out,
self.is_var_seq, self.enable_pdl,
)

Comment thread flashinfer/mla/_core.py
Comment on lines +930 to +933
return runner.forward(
[query, kv_cache, workspace_buffer, block_tables, seq_lens],
tactic=_tactic,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When executing the selected runner after autotuning, the out tensor provided by the user is not passed to the forward call. This will result in the user-provided buffer being ignored and a new tensor being returned instead.

Suggested change
return runner.forward(
[query, kv_cache, workspace_buffer, block_tables, seq_lens],
tactic=_tactic,
)
return runner.forward(
[query, kv_cache, workspace_buffer, block_tables, seq_lens],
tactic=_tactic,
out=out,
)

Comment thread flashinfer/mla/_core.py
enable_pdl, is_var_seq, uses_shared_paged_kv_idx,
):
"""Execute trtllm-gen backend (shared by runner and direct dispatch)."""
run_func = gen_trtllm_gen_fmha_module().trtllm_paged_attention_decode
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using gen_trtllm_gen_fmha_module() directly here is inefficient as it recreates the module object on every call. Please use the cached version get_trtllm_gen_fmha_module() instead, which is already defined in this file.

Suggested change
run_func = gen_trtllm_gen_fmha_module().trtllm_paged_attention_decode
run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode

@jwu1980 jwu1980 closed this Apr 16, 2026
@jwu1980 jwu1980 deleted the feat/mla-decode-autotune branch April 16, 2026 21:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants