
Support Kimi K2.5 H64 CuTe DSL MLA decode #3235

Merged
saltyminty merged 4 commits into main from fix/mingyangw/support-kimi-k2-5-config-for-cutadsl-mla-decode
May 11, 2026
Conversation

@saltyminty
Collaborator

@saltyminty saltyminty commented May 5, 2026

#3161

Summary

  • Enable CuTe DSL MLA decode for Kimi K2.5-style 64 query heads.
  • Allow H=64 split-KV by padding split-KV workspace storage to the physical 128-head MMA lane width.
  • Fold 64-head coverage into the existing CuTe DSL MLA decode test sweep.

Notes

An experimental 64-wide MMA-M path was investigated separately, but it was reverted and is not part of this PR. The branch intentionally keeps the GPU-validated padded implementation.

Validation

  • pre-commit run --files flashinfer/cute_dsl/attention/collective_builder.py flashinfer/cute_dsl/attention/mla_config.py flashinfer/cute_dsl/attention/mla_decode.py flashinfer/cute_dsl/attention/mla_decode_fp8.py flashinfer/cute_dsl/attention/scheduler/mla_persistent.py flashinfer/cute_dsl/attention/wrappers/batch_mla.py tests/attention/test_cute_dsl_mla_decode.py
  • SM100 smoke: H=64 CuTe DSL MLA public API with computed split_kv=32, output shape (1, 1, 64, 512), dtype torch.float16, no NaNs.
  • SM100 padded benchmark: B=1/S=128 0.027255 ms, B=1/S=512 0.027000 ms, B=4/S=128 0.026925 ms, B=4/S=512 0.026979 ms.

Summary by CodeRabbit

  • New Features

    • Relaxed head-count constraints so MLA decode accepts smaller head counts (such as H=64) and more configurations.
  • Bug Fixes

    • Workspace sizing now pads head dimension to 128 for accumulator/layout computations and when computing workspace for split-KV, preventing layout/address issues.
    • Removed runtime rejection for small head counts.
  • Tests

    • Expanded MLA decode tests to cover multiple head counts and added NaN/Inf (finiteness) assertions.

@coderabbitai
Contributor

coderabbitai Bot commented May 5, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cab0b38e-15e2-474e-af5b-604a64262e01

📥 Commits

Reviewing files that changed from the base of the PR and between 67a9838 and b4d193c.

📒 Files selected for processing (7)
  • flashinfer/cute_dsl/attention/mla_config.py
  • flashinfer/cute_dsl/attention/mla_decode.py
  • flashinfer/cute_dsl/attention/mla_decode_fp8.py
  • flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py
  • tests/attention/test_cute_dsl_mla_decode.py
  • tests/attention/test_trtllm_gen_mla.py
💤 Files with no reviewable changes (1)
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • flashinfer/cute_dsl/attention/mla_config.py
  • tests/attention/test_trtllm_gen_mla.py
  • flashinfer/cute_dsl/attention/mla_decode_fp8.py
  • flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
  • flashinfer/cute_dsl/attention/mla_decode.py
  • tests/attention/test_cute_dsl_mla_decode.py

📝 Walkthrough

This PR relaxes MLA decode gating for small heads by removing (H < 128 and split_kv != 1) checks, padding workspace head dimension to 128 for split-KV, updating workspace layouts/strides and workspace-size calculations, removing runtime small-H rejections, and extending tests to cover 64-head cases.

Changes

MLA Small Head Dimension Support

| Layer | File(s) | Summary |
|---|---|---|
| Config Validation | flashinfer/cute_dsl/attention/mla_config.py | can_implement and can_implement_fp8 now only disallow H > 128, removing the previous (H < 128 and split_kv != 1) rejection. |
| Memory Layout & Workspace Allocation | flashinfer/cute_dsl/attention/mla_decode.py, flashinfer/cute_dsl/attention/mla_decode_fp8.py | initialize_workspace uses workspace_H = max(H, 128) for acc_o/acc_lse shapes and updates stride/align calculations accordingly; acc_lse offset still derived from cosize(acc_o_layout). |
| Workspace Size Calculation | flashinfer/cute_dsl/attention/scheduler/mla_persistent.py | mla_get_workspace_size uses workspace_heads = max(H, 128) when split_kv != 1 to account for 128-wide MMA-M tile packing. |
| Runtime Validation Removal | flashinfer/cute_dsl/attention/wrappers/batch_mla.py | Removed runtime enforcement that split_kv == 1 for H < 128; compute split_kv and workspace_size earlier and drop the hard head-count branch. |
| Test Coverage | tests/attention/test_cute_dsl_mla_decode.py, tests/attention/test_trtllm_gen_mla.py | Parametrize MLA decode tests over num_heads (128, 64), remove hard-coded head counts, add torch.isfinite(out).all() assertions, and change cute-dsl skip guard to a dimension-structure equality check. |
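
The padding rule in the table above can be illustrated with a small GPU-free sketch. Only the `max(H, 128)` rule for `split_kv != 1` and the placement of acc_lse after acc_o come from the walkthrough. The helper names `workspace_heads` and `workspace_layout`, and the assumed compact shapes `(batch, split_kv, heads, latent_dim)` for acc_o and `(batch, split_kv, heads)` for acc_lse, are illustrative assumptions rather than flashinfer's actual signatures.

```python
MMA_M_TILE = 128  # physical head-tile width named in the PR


def workspace_heads(num_heads: int, split_kv: int) -> int:
    """Head count used for split-KV scratch sizing (hypothetical helper)."""
    if split_kv == 1:
        # No partial outputs are written, so nothing needs padding.
        return num_heads
    return max(num_heads, MMA_M_TILE)


def workspace_layout(batch: int, num_heads: int, split_kv: int, latent_dim: int) -> dict:
    """Element counts under the ASSUMED compact shapes described in the lead-in."""
    if split_kv == 1:
        return {"acc_o_elems": 0, "acc_lse_offset": 0, "acc_lse_elems": 0}
    heads = workspace_heads(num_heads, split_kv)
    acc_o_elems = batch * split_kv * heads * latent_dim
    return {
        "acc_o_elems": acc_o_elems,
        # acc_lse starts where acc_o ends (for a compact layout, its cosize).
        "acc_lse_offset": acc_o_elems,
        "acc_lse_elems": batch * split_kv * heads,
    }


layout = workspace_layout(batch=1, num_heads=64, split_kv=32, latent_dim=512)
```

With H=64 and split_kv=32 the scratch is sized as if there were 128 heads, which is twice what a tight H=64 layout would need; that is the cost of reusing the 128-wide tile.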

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • aleozlx
  • yzh119
  • cyx-6
  • bkryu
  • yyihuang
  • kahyunnam
  • jimmyzho

"I hop with padded heads to make decodes right,
I pack lanes to one-two-eight and hold them tight,
Split-KV may touch the padded rows,
Tests try sixty-four and check for flows,
A rabbit cheeps — MLA now sees light!"

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically identifies the main change: enabling CuTe DSL MLA decode for Kimi K2.5 H64 configurations, which is the primary objective of the PR. |
| Description check | ✅ Passed | The PR description includes a link to the related issue, a clear summary section explaining the changes (H=64 support via padding), validation results with pre-commit and GPU tests, and notes on the implementation approach. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (2)
flashinfer/cute_dsl/attention/mla_config.py (1)

142-184: 💤 Low value

split_kv parameter is now unused in can_implement / can_implement_fp8.

After dropping the small-H gate, neither overload references split_kv. Either re-introduce a check that depends on it (e.g., upper bound vs. MAX_SPLITS/K) or annotate the parameter as intentionally unused so it doesn't read as dead. Keeping the signature stable is fine.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/cute_dsl/attention/mla_config.py` around lines 142 - 184, The
split_kv parameter is currently unused in can_implement and can_implement_fp8;
update these functions to either (A) enforce a meaningful constraint using
split_kv (for example compare split_kv against a MAX_SPLITS constant or derive
an upper bound from K) by adding that check into can_implement /
can_implement_fp8, or (B) explicitly mark split_kv as intentionally unused to
silence dead-parameter concerns (e.g., rename to _split_kv or add an explicit
unused annotation/comment/cast) while keeping the public signature stable; apply
the same choice consistently to both can_implement and can_implement_fp8.
tests/attention/test_cute_dsl_mla_decode.py (1)

309-326: ⚡ Quick win

Add a NaN check to make this a useful regression for the H=64 padding path.

The PR's smoke-test acceptance criterion for H=64 was "no NaNs", but the test only asserts the output shape. Without a finiteness check, a regression where padded lanes leak into the reduction (the exact failure mode this PR's workspace_H = max(H, 128) fix prevents) would still pass.

♻️ Proposed addition
     assert out.shape == (batch_size, q_len, num_heads, latent_dim)
+    assert not torch.isnan(out).any(), "cute-dsl MLA decode produced NaNs"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 309 - 326, The test
currently only checks out.shape after calling
trtllm_batch_decode_with_kv_cache_mla, so a regression that produces NaNs or
Infs (the H=64 padding bug) would still pass; add a finiteness check right after
the call that asserts all elements of out are finite (no NaN/Inf) — reference
the call to trtllm_batch_decode_with_kv_cache_mla and the resulting variable out
and replace/augment the existing shape assertion with an additional assertion
using the test framework's numeric/torch/isfinite check to fail if any element
is not finite and include a clear message.
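
The difference between the NaN-only assertion first proposed and the finiteness assertion requested in the prompt can be shown in plain Python. This is only a sketch: `math.isnan` and `math.isfinite` stand in for `torch.isnan` and `torch.isfinite` so that it runs without a GPU, and the helper names `has_nan` and `all_finite` are hypothetical.

```python
import math


def has_nan(values):
    """NaN-only check, analogous to torch.isnan(out).any()."""
    return any(math.isnan(v) for v in values)


def all_finite(values):
    """Finiteness check, analogous to torch.isfinite(out).all()."""
    return all(math.isfinite(v) for v in values)


clean = [0.25, -1.5, 3.0]
with_nan = [0.25, float("nan"), 3.0]
with_inf = [0.25, float("inf"), 3.0]

# An Inf slips past the NaN-only check but fails the finiteness check.
nan_only_passes_inf = not has_nan(with_inf)
finite_check_catches_inf = not all_finite(with_inf)
```

The finiteness form is the stricter of the two, so a regression that produces Inf rather than NaN is still caught.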

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f9f0d24a-e636-4005-9f79-05ca826cbdc8

📥 Commits

Reviewing files that changed from the base of the PR and between ba30d4f and 7f3bbed.

📒 Files selected for processing (7)
  • flashinfer/cute_dsl/attention/mla_config.py
  • flashinfer/cute_dsl/attention/mla_decode.py
  • flashinfer/cute_dsl/attention/mla_decode_fp8.py
  • flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py
  • tests/attention/test_cute_dsl_mla_decode.py
  • tests/attention/test_trtllm_gen_mla.py
💤 Files with no reviewable changes (1)
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request enables support for Multi-Head Latent Attention (MLA) with fewer than 128 heads in the cute-dsl backend by padding the workspace to a minimum of 128 heads. This ensures compatibility with the 128-wide MMA-M tile used in the kernels. The changes update workspace initialization, layout calculations, and size estimation for both FP16 and FP8 variants, while also removing outdated runtime validation checks and expanding test coverage. I have no feedback to provide as the existing review comments were purely explanatory and did not identify any issues.

@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !631 has been created, and the CI pipeline #50366592 is currently running. I'll report back once the pipeline job completes.

Comment thread tests/attention/test_cute_dsl_mla_decode.py Outdated
Collaborator

Same here.

Collaborator Author

Addressed.

@saltyminty saltyminty force-pushed the fix/mingyangw/support-kimi-k2-5-config-for-cutadsl-mla-decode branch from 52e4c69 to a925da9 Compare May 7, 2026 18:22
Contributor

@coderabbitai coderabbitai Bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/attention/test_cute_dsl_mla_decode.py (1)

976-1044: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Missing torch.isfinite assertion — inconsistent with the other two updated tests.

test_cute_dsl_mla_decode_via_api (line 327) and test_cute_dsl_mla_decode_fp8 (line 463) both received an explicit torch.isfinite(out).all() guard in this PR, but test_cute_dsl_mla_decode_fp8_alibi did not. Add the check before the assert_close call to keep the pattern uniform.

🛡️ Proposed fix
     out = wrapper.run(...)

+    assert torch.isfinite(out).all(), "FP8 ALiBi cute-dsl MLA decode produced non-finite values"
+
     D_qk = latent_dim + rope_dim
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 976 - 1044, Add a
finite-value guard before the assertion in test_cute_dsl_mla_decode_fp8_alibi:
verify torch.isfinite(out).all() (same pattern used in
test_cute_dsl_mla_decode_via_api and test_cute_dsl_mla_decode_fp8) immediately
before calling torch.testing.assert_close on out so the test checks for
non-NaN/inf values prior to comparison.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f2cc339a-c384-474a-9ce5-739266bcfac1

📥 Commits

Reviewing files that changed from the base of the PR and between 7f3bbed and a925da9d071ebe27d0b73c2eb89ae687d7d3e712.

📒 Files selected for processing (7)
  • flashinfer/cute_dsl/attention/mla_config.py
  • flashinfer/cute_dsl/attention/mla_decode.py
  • flashinfer/cute_dsl/attention/mla_decode_fp8.py
  • flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py
  • tests/attention/test_cute_dsl_mla_decode.py
  • tests/attention/test_trtllm_gen_mla.py
💤 Files with no reviewable changes (1)
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
  • flashinfer/cute_dsl/attention/mla_decode.py
  • flashinfer/cute_dsl/attention/mla_config.py
  • flashinfer/cute_dsl/attention/mla_decode_fp8.py

@saltyminty saltyminty force-pushed the fix/mingyangw/support-kimi-k2-5-config-for-cutadsl-mla-decode branch from a925da9 to 67a9838 Compare May 7, 2026 20:59
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/cute_dsl/attention/mla_config.py`:
- Line 177: The predicate currently allows any H <= 128 which admits unsupported
head counts like 96; change the checks that reference the variable H (at the
predicates around lines with H > 128 and the other at line ~221) to only accept
H values of 64 or 128 (e.g., replace the broad check with a guard that fails
when H not in (64, 128)). Update both occurrences (the predicate using H and the
second check at line ~221) to explicitly enforce H == 64 or H == 128 and
return/raise the same failure path used elsewhere so downstream kernel/layout
assumptions cannot be reached with unsupported H.
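
The two predicates discussed in this comment can be compared side by side. This is a sketch under stated assumptions: `relaxed_head_check` mirrors the "only disallow H > 128" rule from the walkthrough, `strict_head_check` mirrors the reviewer's suggestion, and both function names are hypothetical. The thread does not show which form the final code uses.

```python
SUPPORTED_HEADS = (64, 128)  # head counts named in the review comment


def relaxed_head_check(num_heads: int) -> bool:
    """Accept any head count up to the 128-wide tile."""
    return num_heads <= 128


def strict_head_check(num_heads: int) -> bool:
    """Accept only the head counts that were explicitly validated."""
    return num_heads in SUPPORTED_HEADS


candidates = [32, 64, 96, 128, 256]
# Head counts the relaxed predicate admits but the strict one rejects.
only_relaxed = [
    h for h in candidates if relaxed_head_check(h) and not strict_head_check(h)
]
```

Here `only_relaxed` contains 32 and 96, the kind of untested configuration the comment is concerned about.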

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 742e8c61-6a4d-4d70-abe7-dbe502b31ccb

📥 Commits

Reviewing files that changed from the base of the PR and between a925da9d071ebe27d0b73c2eb89ae687d7d3e712 and 67a9838.

📒 Files selected for processing (7)
  • flashinfer/cute_dsl/attention/mla_config.py
  • flashinfer/cute_dsl/attention/mla_decode.py
  • flashinfer/cute_dsl/attention/mla_decode_fp8.py
  • flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py
  • tests/attention/test_cute_dsl_mla_decode.py
  • tests/attention/test_trtllm_gen_mla.py
💤 Files with no reviewable changes (1)
  • flashinfer/cute_dsl/attention/wrappers/batch_mla.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • tests/attention/test_trtllm_gen_mla.py
  • flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
  • flashinfer/cute_dsl/attention/mla_decode.py
  • flashinfer/cute_dsl/attention/mla_decode_fp8.py
  • tests/attention/test_cute_dsl_mla_decode.py

Comment thread flashinfer/cute_dsl/attention/mla_config.py
@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !631 has been updated with latest changes, and the CI pipeline #50601021 is currently running. I'll report back once the pipeline job completes.

Collaborator

@qsang-nv qsang-nv left a comment

LGTM

saltyminty added 3 commits May 8, 2026 10:44
CuTeDSL MLA decode uses a 128-wide physical head tile, so H=64 split-KV needs partial-output workspace pitched to that physical width before reduction. Keep logical output and reduction over H while allocating/pitching scratch with max(H, 128).

Constraint: Existing SM100 CuTeDSL MLA decode config uses 128-wide QK/PV M tiles.
Rejected: Leave split_kv disabled for H=64 | Kimi K2.5 needs the split-KV path.
Confidence: medium
Scope-risk: moderate
Directive: Do not shrink split-KV workspace pitch below the physical MMA M tile without adding a true smaller-M kernel specialization.
Tested: python3 -m py_compile on changed Python files; git diff --check; pre-commit hooks; remote SM100 public API smoke with H=64 and split_kv=32 before final expression cleanup.
Not-tested: Full pytest sweep; remote smoke after replacing literal 128 with cutlass.max(H, 128).

…plit_kv from can_implement

Address review feedback by strengthening the H64 regression coverage and removing split_kv from the capability-check signatures now that split_kv no longer gates support eligibility.

Constraint: Keep runtime split_kv handling unchanged; this only affects can_implement eligibility checks.
Confidence: high
Scope-risk: narrow
Tested: pre-commit run --files flashinfer/cute_dsl/attention/mla_config.py flashinfer/cute_dsl/attention/mla_decode.py flashinfer/cute_dsl/attention/mla_decode_fp8.py flashinfer/cute_dsl/attention/wrappers/batch_mla.py tests/attention/test_cute_dsl_mla_decode.py
Tested: remote SM100 focused pytest, 3 passed in 11.09s, log /home/scratch.mingyangw_gpu/flashinfer-3161-validation/logs/pr3235-comment-fixes-focused-v2.log
Not-tested: full test_cute_dsl_mla_decode.py rerun after review-feedback patch

Add the missing finite-output guard before the FP8 ALiBi MLA decode reference comparison so non-finite values fail explicitly before tolerance checks.

Constraint: Preserve the review-requested commit subject.
Confidence: high
Scope-risk: narrow
Tested: pre-commit run --files tests/attention/test_cute_dsl_mla_decode.py
Tested: python3 -m py_compile tests/attention/test_cute_dsl_mla_decode.py
Tested: remote SM100 focused pytest, 1 passed in 5.36s, log /home/scratch.mingyangw_gpu/flashinfer-3161-validation/logs/pr3235-alibi-finite-guard.log
Not-tested: full test_cute_dsl_mla_decode.py rerun after this one-line guard
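
The first commit message describes keeping the reduction over the logical H while pitching scratch to the physical tile width. A minimal CPU sketch of that idea follows, using the standard log-sum-exp merge of split-KV partials. It is not the kernel's code: the helpers `attend` and `reduce_splits` are hypothetical, values are scalars per KV position for brevity, two heads stand in for H=64, and NaN is used as stand-in garbage for the padded rows.

```python
import math

PHYSICAL_HEADS = 128  # scratch pitch, per the commit message


def attend(scores, values):
    """Softmax-weighted average and log-sum-exp for one head over one KV chunk."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / total
    return out, m + math.log(total)


def reduce_splits(partial_o, partial_lse, num_heads):
    """Merge per-split partials for logical heads only; padded rows are never read."""
    merged = []
    for h in range(num_heads):
        lses = [row[h] for row in partial_lse]
        m = max(lses)
        weights = [math.exp(x - m) for x in lses]
        total = sum(weights)
        merged.append(sum(w * row[h] for w, row in zip(weights, partial_o)) / total)
    return merged


num_heads, split_kv = 2, 2  # small stand-ins for H=64 and a larger split count
scores = [[0.0, 1.0, 2.0, 3.0], [1.0, -1.0, 0.5, 0.0]]  # per head, 4 KV positions
values = [1.0, 2.0, 3.0, 4.0]  # one scalar "V" per KV position

# Scratch rows are pitched to PHYSICAL_HEADS; rows past num_heads keep NaN.
partial_o = [[float("nan")] * PHYSICAL_HEADS for _ in range(split_kv)]
partial_lse = [[float("nan")] * PHYSICAL_HEADS for _ in range(split_kv)]
for s in range(split_kv):
    chunk = slice(2 * s, 2 * s + 2)
    for h in range(num_heads):
        partial_o[s][h], partial_lse[s][h] = attend(scores[h][chunk], values[chunk])

merged = reduce_splits(partial_o, partial_lse, num_heads)
reference = [attend(scores[h], values)[0] for h in range(num_heads)]
```

Because `reduce_splits` iterates over `num_heads` rather than `PHYSICAL_HEADS`, the NaN-filled padded rows never reach the output, and `merged` matches the unsplit `reference` up to floating-point rounding.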
@saltyminty saltyminty force-pushed the fix/mingyangw/support-kimi-k2-5-config-for-cutadsl-mla-decode branch from 67a9838 to b4d193c Compare May 8, 2026 17:44
@saltyminty
Collaborator Author

CI failures all look unrelated.

Contributor

@jimmyzho jimmyzho left a comment

lgtm. Seems like a 64 wide MMA-M implementation is non-trivial?

@saltyminty saltyminty merged commit 2d0e0ef into main May 11, 2026
52 of 72 checks passed
@saltyminty saltyminty deleted the fix/mingyangw/support-kimi-k2-5-config-for-cutadsl-mla-decode branch May 11, 2026 17:00
pgera added a commit to pgera/flashinfer that referenced this pull request May 19, 2026
…te_dsl_impl= (AI-assisted)

PR flashinfer-ai#2805 refactored the monolithic CuTe-DSL MLA decode kernel into a
modular structure and removed the original implementation.  The original
authors want it kept available because the modular path is still
maturing.  Restore it under the cute-dsl backend (no new backend name)
and let the user pick:

  flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
      ..., backend="cute-dsl", cute_dsl_impl="auto" | "modular" | "monolithic")

Layout:
  flashinfer/cute_dsl/attention/
    monolithic/         - restored kernels (verbatim from before flashinfer-ai#2805,
                          relocated to live next to the modular code).
                          Includes the H<128 / Kimi K2.5 fix from flashinfer-ai#3235
                          backported (workspace pads H to max(H, 128)
                          when split_kv != 1; can_implement no longer
                          rejects H<128).
    wrappers/           - existing modular standalone + wrapper.
    mla_dispatch.py     - new dispatcher in front of both impls.

Dispatcher contract:
  - "auto"       (default): monolithic, but auto-promotes to modular
                            when a modular-only feature is requested
                            (currently sinks).
  - "modular"    : strict, always modular.
  - "monolithic" : strict, raises ValueError if a modular-only feature
                   is requested rather than silently substituting.

The dispatcher strips modular-only kwargs (sinks=None) before
forwarding to monolithic, so callers can pass sinks= unconditionally
without breaking the monolithic path.
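
The dispatcher contract described above can be sketched in a few lines. This is an illustration only: `resolve_cute_dsl_impl` and `kwargs_for` are hypothetical names, `sinks` is treated as the single modular-only feature because it is the only one the message names, and rejecting unknown values with `ValueError` is an assumption.

```python
def resolve_cute_dsl_impl(cute_dsl_impl="auto", sinks=None):
    """Pick "modular" or "monolithic" per the described contract (hypothetical helper)."""
    needs_modular = sinks is not None  # sinks is the only modular-only feature named
    if cute_dsl_impl == "auto":
        # Default to monolithic, promote to modular when required.
        return "modular" if needs_modular else "monolithic"
    if cute_dsl_impl == "modular":
        return "modular"
    if cute_dsl_impl == "monolithic":
        if needs_modular:
            # Strict mode: refuse rather than silently substitute.
            raise ValueError("sinks requires the modular implementation")
        return "monolithic"
    # Assumption: unknown selector values are rejected.
    raise ValueError(f"unknown cute_dsl_impl: {cute_dsl_impl!r}")


def kwargs_for(impl, **kwargs):
    """Drop modular-only kwargs (here sinks=None) before forwarding to monolithic."""
    if impl == "monolithic":
        kwargs = {k: v for k, v in kwargs.items() if k != "sinks"}
    return kwargs


impl = resolve_cute_dsl_impl("auto", sinks=None)
forwarded = kwargs_for(impl, sinks=None, page_size=64)
```

Stripping `sinks` only after the implementation has been resolved keeps the strict "monolithic" error path intact while still letting callers pass `sinks=None` unconditionally.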

Sinks support:
  - trtllm_batch_decode_with_kv_cache_mla(sinks=...) on backend=
    "cute-dsl" now constructs an AttentionWithSink variant inside the
    modular standalone, instead of being rejected at the API boundary.
  - AttentionWithSink gains value-based __hash__/__eq__ keyed on
    (data_ptr, shape, dtype) so @functools.cache on _compile_mla_kernel
    correctly reuses compiled kernels across invocations with the same
    sinks tensor.  Without this, a fresh variant per call hashed by
    object identity, JIT-recompiled the kernel on every iteration, and
    made cuda-graph + sinks bench measurements appear to hang.

Tests:
  - test_cute_dsl_mla_decode.py: existing standalone and public-API
    tests now parametrize over modular/monolithic via a cute_dsl_impl
    fixture; new minimal sinks tests pin the auto/modular dispatch
    branches and the monolithic+sinks ValueError contract.  Wrapper
    sinks numerics remain covered by the pre-existing
    test_cute_dsl_mla_decode_attention_sink.
  - test_trtllm_gen_mla.py: comment near the cute-dsl skip refreshed
    to reflect the dispatcher's cute_dsl_impl behaviour.

Bench:
  - bench_trtllm_gen_mla.py grows a focused 6-cell with_sinks=True
    sub-sweep (B in {1,16,128} x S in {1024,8192} at q_len=1, page=64,
    bf16) on top of the existing main sweep, instead of doubling the
    full grid.  Argument list deduplicated into a common_kwargs dict
    so warmup and benchmark calls cannot drift.
saltyminty pushed a commit that referenced this pull request May 21, 2026
…te_dsl_impl= (#3296)

## Summary

PR #2805 refactored the monolithic CuTe-DSL MLA decode kernel into a
modular structure and removed the original implementation. The original
authors want it kept available because the modular path is still
maturing. This PR restores it under the existing
`backend="cute-dsl"` user surface (no new backend name) and exposes
implementation selection via a new `cute_dsl_impl=` keyword argument
on `trtllm_batch_decode_with_kv_cache_mla`.

- **"auto"** (default): monolithic by default, automatically promoted
to modular when the call uses a modular-only feature (currently
`sinks`).
- **"modular"**: strict, always run the modular kernels.
- **"monolithic"**: strict, always run the monolithic kernels; raises
`ValueError` if the call uses any modular-only feature.

The dispatcher strips modular-only kwargs (`sinks=None`) before
forwarding to monolithic, so callers can pass `sinks=` unconditionally
without breaking the monolithic path.

### Sinks support on cute-dsl backend

`trtllm_batch_decode_with_kv_cache_mla(sinks=...)` on
`backend="cute-dsl"` now constructs an `AttentionWithSink` variant
inside the modular standalone, instead of being rejected at the API
boundary. `AttentionWithSink` gained value-based
`__hash__`/`__eq__` (keyed on `(type, shape, dtype)`) so
`@functools.cache` on `_compile_mla_kernel` correctly reuses
compiled kernels across invocations with the same shape. Without this,
a fresh variant per call hashed by object identity, JIT-recompiled the
kernel on every iteration, and made cuda-graph + sinks bench
measurements appear to hang.
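
The caching behaviour described here can be reproduced without any GPU code. In this sketch `FakeTensor`, `IdentityVariant`, `ValueVariant`, and `compile_kernel` are all hypothetical stand-ins; only the idea of keying `__hash__`/`__eq__` on `(type, shape, dtype)` under `functools.cache` comes from the description.

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a sinks tensor; only shape and dtype matter here."""
    shape: tuple
    dtype: str


class IdentityVariant:
    """Default object hashing: every new instance is a new cache key."""

    def __init__(self, sinks):
        self.sinks = sinks


class ValueVariant:
    """Hash and equality keyed on (type, shape, dtype)."""

    def __init__(self, sinks):
        self.sinks = sinks

    def _key(self):
        return (type(self), self.sinks.shape, self.sinks.dtype)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        return isinstance(other, ValueVariant) and self._key() == other._key()


compile_log = []  # one entry per simulated JIT compilation


@functools.cache
def compile_kernel(variant):
    compile_log.append(variant)
    return len(compile_log)


def compiles_for(variant_cls, calls=3):
    """Count simulated compilations when a fresh variant is built on each call."""
    compile_kernel.cache_clear()
    compile_log.clear()
    for _ in range(calls):
        compile_kernel(variant_cls(FakeTensor((128,), "float32")))
    return len(compile_log)


identity_compiles = compiles_for(IdentityVariant)
value_compiles = compiles_for(ValueVariant)
```

With identity hashing each call misses the cache, so the simulated compile runs every time; with value-based hashing the first call compiles and later calls with the same shape and dtype reuse the result.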

### Layout

```
flashinfer/cute_dsl/attention/
  monolithic/         - restored kernels (verbatim from before #2805,
                        relocated to live next to the modular code).
                        Includes the H<128 / Kimi K2.5 fix from #3235
                        backported.
  wrappers/           - existing modular standalone + wrapper.
  mla_dispatch.py     - new dispatcher in front of both impls.
```

### Bench

`benchmarks/bench_trtllm_gen_mla.py` grows a focused 6-cell
`with_sinks=True` sub-sweep (B in {1,16,128} × S in {1024,8192} at
q_len=1, page=64, bf16) on top of the existing main sweep, instead of
doubling the full grid. Argument list deduplicated into a
`common_kwargs` dict so warmup and benchmark calls cannot drift.
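
The shape of the sub-sweep and the shared-kwargs pattern can be sketched as follows. This is not the benchmark script: `run_cell`, `fake_decode`, and the keyword names are hypothetical, and no timing is performed so that the example stays deterministic.

```python
import itertools

BATCH_SIZES = (1, 16, 128)
SEQ_LENS = (1024, 8192)

# 3 batch sizes x 2 sequence lengths = the 6 cells of the sinks sub-sweep.
sinks_cells = list(itertools.product(BATCH_SIZES, SEQ_LENS))


def run_cell(fn, batch, seq_len, warmup_iters=2, bench_iters=5):
    """Call fn with one shared kwargs dict for both warmup and measured runs."""
    common_kwargs = dict(
        batch=batch,
        seq_len=seq_len,
        q_len=1,
        page_size=64,
        dtype="bf16",
        with_sinks=True,
    )
    for _ in range(warmup_iters):
        fn(**common_kwargs)
    return [fn(**common_kwargs) for _ in range(bench_iters)]


seen = []  # records every call so the two phases can be compared


def fake_decode(**kwargs):
    """Stand-in for the decode call under test."""
    seen.append(kwargs)
    return kwargs["batch"] * kwargs["seq_len"]


results = {cell: run_cell(fake_decode, *cell) for cell in sinks_cells}
```

Building the dict once per cell means a parameter added for the measured call is automatically present in the warmup call as well.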

Sinks overhead is ~free on both backends (worst case +1.9% at the
smallest cell). Cross-backend ranking does not change with sinks
enabled.

## Test plan

Existing standalone and public-API tests in
`tests/attention/test_cute_dsl_mla_decode.py` now parametrize over
modular/monolithic via a `cute_dsl_impl` fixture, doubling coverage on
the same shapes. New minimal sinks tests pin the auto/modular dispatch
branches and the monolithic+sinks `ValueError` contract. Wrapper sinks
numerics remain covered by the pre-existing
`test_cute_dsl_mla_decode_attention_sink`.

- [x] All pre-commit hooks pass on changed files (mypy, ruff check, ruff
format, EOF, whitespace, etc.).
- [x] `pytest tests/attention/test_cute_dsl_mla_decode.py -v`: full
sweep (544 cases incl. parametrized modular/monolithic) passes on B200.
- [x] `pytest tests/attention/test_cute_dsl_mla_decode.py -k sinks`:
3 new sinks integration tests pass.
- [x] `pytest tests/attention/test_trtllm_gen_mla.py -v`: unaffected,
passes.
- [x] H=64 / Kimi K2.5 backport on monolithic exercised via the existing
`num_heads ∈ [128, 64]` parametrization ×
`cute_dsl_impl=monolithic`.
- [x] Bench `benchmarks/bench_trtllm_gen_mla.py --backend cute-dsl`
and `--backend trtllm-gen` both run cleanly through the focused sinks
sub-sweep.

## Summary by CodeRabbit

* **New Features**
* Optional "sinks" support for an alternate softmax path and a
cute_dsl_impl option to choose/auto-select modular vs monolithic MLA
decode implementations.
  * New monolithic CuTe-based MLA kernel targeting Blackwell hardware.

* **Performance / Reliability**
* Improved kernel caching/variant keying to enable reuse across variant
instances.
* Benchmark updated to exercise sinks-enabled and sinks-disabled paths.

* **Tests**
* Added tests for implementation selection, sinks behavior, and related
error/shape validations.


4 participants