Support Kimi K2.5 H64 CuTe DSL MLA decode #3235
📝 Walkthrough

This PR relaxes MLA decode gating for small head counts by removing the `(H < 128 and split_kv != 1)` checks, padding the workspace head dimension to 128 for split-KV, updating workspace layouts/strides and workspace-size calculations, removing runtime small-H rejections, and extending tests to cover 64-head cases.

Changes
MLA Small Head Dimension Support
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
🧹 Nitpick comments (2)

flashinfer/cute_dsl/attention/mla_config.py (1)

142-184: 💤 Low value

`split_kv` parameter is now unused in `can_implement` / `can_implement_fp8`.

After dropping the small-H gate, neither overload references `split_kv`. Either re-introduce a check that depends on it (e.g., upper bound vs. `MAX_SPLITS/K`) or annotate the parameter as intentionally unused so it doesn't read as dead. Keeping the signature stable is fine.

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/cute_dsl/attention/mla_config.py` around lines 142 - 184, The split_kv parameter is currently unused in can_implement and can_implement_fp8; update these functions to either (A) enforce a meaningful constraint using split_kv (for example compare split_kv against a MAX_SPLITS constant or derive an upper bound from K) by adding that check into can_implement / can_implement_fp8, or (B) explicitly mark split_kv as intentionally unused to silence dead-parameter concerns (e.g., rename to _split_kv or add an explicit unused annotation/comment/cast) while keeping the public signature stable; apply the same choice consistently to both can_implement and can_implement_fp8.

tests/attention/test_cute_dsl_mla_decode.py (1)

309-326: ⚡ Quick win

Add a NaN check to make this a useful regression for the H=64 padding path.

The PR's smoke-test acceptance criterion for H=64 was "no NaNs", but the test only asserts the output shape. Without a finiteness check, a regression where padded lanes leak into the reduction (the exact failure mode this PR's `workspace_H = max(H, 128)` fix prevents) would still pass.

♻️ Proposed addition

     assert out.shape == (batch_size, q_len, num_heads, latent_dim)
    +assert not torch.isnan(out).any(), "cute-dsl MLA decode produced NaNs"

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 309 - 326, The test currently only checks out.shape after calling trtllm_batch_decode_with_kv_cache_mla, so a regression that produces NaNs or Infs (the H=64 padding bug) would still pass; add a finiteness check right after the call that asserts all elements of out are finite (no NaN/Inf). Reference the call to trtllm_batch_decode_with_kv_cache_mla and the resulting variable out, and replace/augment the existing shape assertion with an additional assertion using the test framework's numeric/torch/isfinite check to fail if any element is not finite and include a clear message.
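The proposed addition above asserts only on NaNs, while the agent prompt asks for a full finiteness check. A minimal standard-library sketch of the difference follows; the helper names are hypothetical stand-ins for `torch.isnan(out).any()` and `torch.isfinite(out).all()`, and this is not the test code from the PR.

```python
import math


def has_nan(values):
    """Stand-in for torch.isnan(out).any() on a flat list of floats."""
    return any(math.isnan(v) for v in values)


def all_finite(values):
    """Stand-in for torch.isfinite(out).all() on a flat list of floats."""
    return all(math.isfinite(v) for v in values)


clean = [0.5, -1.25, 3.0]
with_inf = [0.5, math.inf, 3.0]
with_nan = [0.5, math.nan, 3.0]

# A NaN-only check lets +/-inf through; a finiteness check rejects both.
print(has_nan(with_inf), all_finite(with_inf))
print(has_nan(with_nan), all_finite(with_nan))
```

Under this reading, an `isfinite`-style guard is the stricter of the two, which matches the wording of the prompt rather than the one-line diff.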
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f9f0d24a-e636-4005-9f79-05ca826cbdc8
📒 Files selected for processing (7)
- flashinfer/cute_dsl/attention/mla_config.py
- flashinfer/cute_dsl/attention/mla_decode.py
- flashinfer/cute_dsl/attention/mla_decode_fp8.py
- flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
- flashinfer/cute_dsl/attention/wrappers/batch_mla.py
- tests/attention/test_cute_dsl_mla_decode.py
- tests/attention/test_trtllm_gen_mla.py
💤 Files with no reviewable changes (1)
- flashinfer/cute_dsl/attention/wrappers/batch_mla.py
Code Review
This pull request enables support for Multi-Head Latent Attention (MLA) with fewer than 128 heads in the cute-dsl backend by padding the workspace to a minimum of 128 heads. This ensures compatibility with the 128-wide MMA-M tile used in the kernels. The changes update workspace initialization, layout calculations, and size estimation for both FP16 and FP8 variants, while also removing outdated runtime validation checks and expanding test coverage. I have no feedback to provide as the existing review comments were purely explanatory and did not identify any issues.
/bot run
52e4c69 to a925da9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)

tests/attention/test_cute_dsl_mla_decode.py (1)

976-1044: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Missing `torch.isfinite` assertion; inconsistent with the other two updated tests.

`test_cute_dsl_mla_decode_via_api` (line 327) and `test_cute_dsl_mla_decode_fp8` (line 463) both received an explicit `torch.isfinite(out).all()` guard in this PR, but `test_cute_dsl_mla_decode_fp8_alibi` did not. Add the check before the `assert_close` call to keep the pattern uniform.

🛡️ Proposed fix

     out = wrapper.run(...)
    +assert torch.isfinite(out).all(), "FP8 ALiBi cute-dsl MLA decode produced non-finite values"
    +
     D_qk = latent_dim + rope_dim

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 976 - 1044, Add a finite-value guard before the assertion in test_cute_dsl_mla_decode_fp8_alibi: verify torch.isfinite(out).all() (same pattern used in test_cute_dsl_mla_decode_via_api and test_cute_dsl_mla_decode_fp8) immediately before calling torch.testing.assert_close on out so the test checks for non-NaN/inf values prior to comparison.
ℹ️ Review info
Run ID: f2cc339a-c384-474a-9ce5-739266bcfac1
📥 Commits
Reviewing files that changed from the base of the PR and between 7f3bbed and a925da9d071ebe27d0b73c2eb89ae687d7d3e712.
🚧 Files skipped from review as they are similar to previous changes (4)
- flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
- flashinfer/cute_dsl/attention/mla_decode.py
- flashinfer/cute_dsl/attention/mla_config.py
- flashinfer/cute_dsl/attention/mla_decode_fp8.py
a925da9 to 67a9838
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/cute_dsl/attention/mla_config.py`:
- Line 177: The predicate currently allows any H <= 128 which admits unsupported
head counts like 96; change the checks that reference the variable H (at the
predicates around lines with H > 128 and the other at line ~221) to only accept
H values of 64 or 128 (e.g., replace the broad check with a guard that fails
when H not in (64, 128)). Update both occurrences (the predicate using H and the
second check at line ~221) to explicitly enforce H == 64 or H == 128 and
return/raise the same failure path used elsewhere so downstream kernel/layout
assumptions cannot be reached with unsupported H.
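As an illustration of the guard this comment asks for, here is a minimal sketch that contrasts a loose `H <= 128` predicate with an explicit allow-list. The function and constant names are hypothetical; the real `can_implement` in `mla_config.py` performs other checks that are not shown in this thread.

```python
# Assumption: the supported head counts come from the review comment above,
# not from reading mla_config.py.
SUPPORTED_NUM_HEADS = (64, 128)


def loose_head_check(num_heads: int) -> bool:
    """Predicate shape criticized in the review: anything up to 128 passes."""
    return num_heads <= 128


def strict_head_check(num_heads: int) -> bool:
    """Allow-list variant: only the head counts named in the review pass."""
    return num_heads in SUPPORTED_NUM_HEADS


for h in (64, 96, 128, 256):
    print(h, loose_head_check(h), strict_head_check(h))
```

With the allow-list, H=96 is rejected at the capability check instead of reaching kernel or layout code that assumes 64 or 128 heads.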
ℹ️ Review info
Run ID: 742e8c61-6a4d-4d70-abe7-dbe502b31ccb
📥 Commits
Reviewing files that changed from the base of the PR and between a925da9d071ebe27d0b73c2eb89ae687d7d3e712 and 67a9838.
🚧 Files skipped from review as they are similar to previous changes (5)
- tests/attention/test_trtllm_gen_mla.py
- flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
- flashinfer/cute_dsl/attention/mla_decode.py
- flashinfer/cute_dsl/attention/mla_decode_fp8.py
- tests/attention/test_cute_dsl_mla_decode.py
/bot run
CuTeDSL MLA decode uses a 128-wide physical head tile, so H=64 split-KV needs partial-output workspace pitched to that physical width before reduction. Keep logical output and reduction over H while allocating/pitching scratch with max(H, 128).

Constraint: Existing SM100 CuTeDSL MLA decode config uses 128-wide QK/PV M tiles.
Rejected: Leave split_kv disabled for H=64 | Kimi K2.5 needs the split-KV path.
Confidence: medium
Scope-risk: moderate
Directive: Do not shrink split-KV workspace pitch below the physical MMA M tile without adding a true smaller-M kernel specialization.
Tested: python3 -m py_compile on changed Python files; git diff --check; pre-commit hooks; remote SM100 public API smoke with H=64 and split_kv=32 before final expression cleanup.
Not-tested: Full pytest sweep; remote smoke after replacing literal 128 with cutlass.max(H, 128).
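The pitching rule in this commit message can be sketched in plain Python. This is an illustration only: the `[split_kv, workspace_H, latent_dim]` buffer layout, the helper names, and the plain sum used as the combination step are all assumptions, not the kernel's actual layout or reduction.

```python
import math

MMA_M_TILE = 128  # physical head-tile width named in the commit message


def padded_heads(num_heads: int) -> int:
    """Head pitch of the scratch buffer: max(H, 128)."""
    return max(num_heads, MMA_M_TILE)


def workspace_offset(split, head, col, num_heads, latent_dim):
    """Flat offset into an assumed [split_kv, workspace_H, latent_dim] buffer."""
    return (split * padded_heads(num_heads) + head) * latent_dim + col


def reduce_logical_heads(workspace, split_kv, num_heads, latent_dim):
    """Combine partials over splits, visiting only the H logical heads.

    A plain sum stands in for the real split-KV combination step.
    """
    out = [[0.0] * latent_dim for _ in range(num_heads)]
    for s in range(split_kv):
        for h in range(num_heads):
            for c in range(latent_dim):
                out[h][c] += workspace[workspace_offset(s, h, c, num_heads, latent_dim)]
    return out


H, SPLIT_KV, D = 64, 2, 4
# Padding lanes are poisoned with NaN so any leak into the reduction is visible.
workspace = [math.nan] * (SPLIT_KV * padded_heads(H) * D)
for s in range(SPLIT_KV):
    for h in range(H):
        for c in range(D):
            workspace[workspace_offset(s, h, c, H, D)] = 1.0

out = reduce_logical_heads(workspace, SPLIT_KV, H, D)
print(len(workspace), all(math.isfinite(v) for row in out for v in row))
```

The buffer is allocated and strided by the padded head count, while the reduction loops over the logical H only, so the poisoned padding lanes never reach the output.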
…plit_kv from can_implement

Address review feedback by strengthening the H64 regression coverage and removing split_kv from the capability-check signatures now that split_kv no longer gates support eligibility.

Constraint: Keep runtime split_kv handling unchanged; this only affects can_implement eligibility checks.
Confidence: high
Scope-risk: narrow
Tested: pre-commit run --files flashinfer/cute_dsl/attention/mla_config.py flashinfer/cute_dsl/attention/mla_decode.py flashinfer/cute_dsl/attention/mla_decode_fp8.py flashinfer/cute_dsl/attention/wrappers/batch_mla.py tests/attention/test_cute_dsl_mla_decode.py
Tested: remote SM100 focused pytest, 3 passed in 11.09s, log /home/scratch.mingyangw_gpu/flashinfer-3161-validation/logs/pr3235-comment-fixes-focused-v2.log
Not-tested: full test_cute_dsl_mla_decode.py rerun after review-feedback patch
Add the missing finite-output guard before the FP8 ALiBi MLA decode reference comparison so non-finite values fail explicitly before tolerance checks.

Constraint: Preserve the review-requested commit subject.
Confidence: high
Scope-risk: narrow
Tested: pre-commit run --files tests/attention/test_cute_dsl_mla_decode.py
Tested: python3 -m py_compile tests/attention/test_cute_dsl_mla_decode.py
Tested: remote SM100 focused pytest, 1 passed in 5.36s, log /home/scratch.mingyangw_gpu/flashinfer-3161-validation/logs/pr3235-alibi-finite-guard.log
Not-tested: full test_cute_dsl_mla_decode.py rerun after this one-line guard
67a9838 to b4d193c
CI failures all look unrelated.
jimmyzho left a comment
lgtm. Seems like a 64 wide MMA-M implementation is non-trivial?
…te_dsl_impl= (AI-assisted)

PR flashinfer-ai#2805 refactored the monolithic CuTe-DSL MLA decode kernel into a modular structure and removed the original implementation. The original authors want it kept available because the modular path is still maturing. Restore it under the cute-dsl backend (no new backend name) and let the user pick:

    flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        ..., backend="cute-dsl",
        cute_dsl_impl="auto" | "modular" | "monolithic")

Layout:

    flashinfer/cute_dsl/attention/
      monolithic/      - restored kernels (verbatim from before flashinfer-ai#2805, relocated to live next to the modular code). Includes the H<128 / Kimi K2.5 fix from flashinfer-ai#3235 backported (workspace pads H to max(H, 128) when split_kv != 1; can_implement no longer rejects H<128).
      wrappers/        - existing modular standalone + wrapper.
      mla_dispatch.py  - new dispatcher in front of both impls.

Dispatcher contract:
- "auto" (default): monolithic, but auto-promotes to modular when a modular-only feature is requested (currently sinks).
- "modular": strict, always modular.
- "monolithic": strict, raises ValueError if a modular-only feature is requested rather than silently substituting.

The dispatcher strips modular-only kwargs (sinks=None) before forwarding to monolithic, so callers can pass sinks= unconditionally without breaking the monolithic path.

Sinks support:
- trtllm_batch_decode_with_kv_cache_mla(sinks=...) on backend="cute-dsl" now constructs an AttentionWithSink variant inside the modular standalone, instead of being rejected at the API boundary.
- AttentionWithSink gains value-based __hash__/__eq__ keyed on (data_ptr, shape, dtype) so @functools.cache on _compile_mla_kernel correctly reuses compiled kernels across invocations with the same sinks tensor. Without this, a fresh variant per call hashed by object identity, JIT-recompiled the kernel on every iteration, and made cuda-graph + sinks bench measurements appear to hang.

Tests:
- test_cute_dsl_mla_decode.py: existing standalone and public-API tests now parametrize over modular/monolithic via a cute_dsl_impl fixture; new minimal sinks tests pin the auto/modular dispatch branches and the monolithic+sinks ValueError contract. Wrapper sinks numerics remain covered by the pre-existing test_cute_dsl_mla_decode_attention_sink.
- test_trtllm_gen_mla.py: comment near the cute-dsl skip refreshed to reflect the dispatcher's cute_dsl_impl behaviour.

Bench:
- bench_trtllm_gen_mla.py grows a focused 6-cell with_sinks=True sub-sweep (B in {1,16,128} x S in {1024,8192} at q_len=1, page=64, bf16) on top of the existing main sweep, instead of doubling the full grid. Argument list deduplicated into a common_kwargs dict so warmup and benchmark calls cannot drift.
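The dispatcher contract described in this commit message can be sketched as a small selection function. The names `select_impl` and `strip_modular_only_kwargs` are hypothetical and are not taken from `mla_dispatch.py`; rejecting unknown `cute_dsl_impl` values with `ValueError` is also an assumption.

```python
VALID_IMPLS = ("auto", "modular", "monolithic")


def select_impl(cute_dsl_impl="auto", sinks=None):
    """Return "modular" or "monolithic" following the contract above."""
    if cute_dsl_impl not in VALID_IMPLS:
        # Assumption: the message does not say how unknown values are handled.
        raise ValueError(f"unknown cute_dsl_impl: {cute_dsl_impl!r}")
    needs_modular = sinks is not None  # sinks is the only modular-only feature named
    if cute_dsl_impl == "modular":
        return "modular"
    if cute_dsl_impl == "monolithic":
        if needs_modular:
            raise ValueError("sinks requires the modular implementation")
        return "monolithic"
    # "auto": monolithic unless a modular-only feature is requested.
    return "modular" if needs_modular else "monolithic"


def strip_modular_only_kwargs(impl, kwargs):
    """Drop modular-only kwargs (here: sinks) before calling monolithic."""
    if impl == "monolithic":
        return {k: v for k, v in kwargs.items() if k != "sinks"}
    return dict(kwargs)


impl = select_impl("auto", sinks=None)
print(impl, strip_modular_only_kwargs(impl, {"sinks": None, "split_kv": 32}))
```

The strict "monolithic" branch raises rather than silently switching implementations, which is the behaviour the message calls out.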
…te_dsl_impl= (#3296)

## Summary

PR #2805 refactored the monolithic CuTe-DSL MLA decode kernel into a modular structure and removed the original implementation. The original authors want it kept available because the modular path is still maturing. This PR restores it under the existing `backend="cute-dsl"` user surface (no new backend name) and exposes implementation selection via a new `cute_dsl_impl=` keyword argument on `trtllm_batch_decode_with_kv_cache_mla`.

- **"auto"** (default): monolithic by default, automatically promoted to modular when the call uses a modular-only feature (currently `sinks`).
- **"modular"**: strict, always run the modular kernels.
- **"monolithic"**: strict, always run the monolithic kernels; raises `ValueError` if the call uses any modular-only feature.

The dispatcher strips modular-only kwargs (`sinks=None`) before forwarding to monolithic, so callers can pass `sinks=` unconditionally without breaking the monolithic path.

### Sinks support on cute-dsl backend

`trtllm_batch_decode_with_kv_cache_mla(sinks=...)` on `backend="cute-dsl"` now constructs an `AttentionWithSink` variant inside the modular standalone, instead of being rejected at the API boundary. `AttentionWithSink` gained value-based `__hash__`/`__eq__` (keyed on `(type, shape, dtype)`) so `@functools.cache` on `_compile_mla_kernel` correctly reuses compiled kernels across invocations with the same shape. Without this, a fresh variant per call hashed by object identity, JIT-recompiled the kernel on every iteration, and made cuda-graph + sinks bench measurements appear to hang.

### Layout

```
flashinfer/cute_dsl/attention/
  monolithic/      - restored kernels (verbatim from before #2805, relocated to live next to the modular code). Includes the H<128 / Kimi K2.5 fix from #3235 backported.
  wrappers/        - existing modular standalone + wrapper.
  mla_dispatch.py  - new dispatcher in front of both impls.
```

### Bench

`benchmarks/bench_trtllm_gen_mla.py` grows a focused 6-cell `with_sinks=True` sub-sweep (B in {1,16,128} × S in {1024,8192} at q_len=1, page=64, bf16) on top of the existing main sweep, instead of doubling the full grid. Argument list deduplicated into a `common_kwargs` dict so warmup and benchmark calls cannot drift.

Sinks overhead is ~free on both backends (worst case +1.9% at the smallest cell). Cross-backend ranking does not change with sinks enabled.

## Test plan

Existing standalone and public-API tests in `tests/attention/test_cute_dsl_mla_decode.py` now parametrize over modular/monolithic via a `cute_dsl_impl` fixture, doubling coverage on the same shapes. New minimal sinks tests pin the auto/modular dispatch branches and the monolithic+sinks `ValueError` contract. Wrapper sinks numerics remain covered by the pre-existing `test_cute_dsl_mla_decode_attention_sink`.

- [x] All pre-commit hooks pass on changed files (mypy, ruff check, ruff format, EOF, whitespace, etc.).
- [x] `pytest tests/attention/test_cute_dsl_mla_decode.py -v`: full sweep (544 cases incl. parametrized modular/monolithic) passes on B200.
- [x] `pytest tests/attention/test_cute_dsl_mla_decode.py -k sinks`: 3 new sinks integration tests pass.
- [x] `pytest tests/attention/test_trtllm_gen_mla.py -v`: unaffected, passes.
- [x] H=64 / Kimi K2.5 backport on monolithic exercised via the existing `num_heads ∈ [128, 64]` parametrization × `cute_dsl_impl=monolithic`.
- [x] Bench `benchmarks/bench_trtllm_gen_mla.py --backend cute-dsl` and `--backend trtllm-gen` both run cleanly through the focused sinks sub-sweep.

## Summary by CodeRabbit

* **New Features**
  * Optional "sinks" support for an alternate softmax path and a cute_dsl_impl option to choose/auto-select modular vs monolithic MLA decode implementations.
  * New monolithic CuTe-based MLA kernel targeting Blackwell hardware.
* **Performance / Reliability**
  * Improved kernel caching/variant keying to enable reuse across variant instances.
  * Benchmark updated to exercise sinks-enabled and sinks-disabled paths.
* **Tests**
  * Added tests for implementation selection, sinks behavior, and related error/shape validations.
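The caching issue described under "Sinks support" can be reproduced with a small stand-in. The classes and the `compile_kernel` function below are hypothetical, not the `AttentionWithSink` or `_compile_mla_kernel` code; the key follows the `(type, shape, dtype)` wording of this description.

```python
import functools


class ValueKeyedVariant:
    """Hashes and compares by (type, shape, dtype), as the description states."""

    def __init__(self, shape, dtype):
        self.shape = tuple(shape)
        self.dtype = dtype

    def _key(self):
        return (type(self), self.shape, self.dtype)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        return isinstance(other, ValueKeyedVariant) and self._key() == other._key()


class IdentityKeyedVariant:
    """Default object hashing: every fresh instance is a distinct cache key."""

    def __init__(self, shape, dtype):
        self.shape = tuple(shape)
        self.dtype = dtype


compile_count = 0


@functools.cache
def compile_kernel(variant):
    """Stand-in for an expensive JIT compile; counts how often it really runs."""
    global compile_count
    compile_count += 1
    return f"kernel-{compile_count}"


for _ in range(3):
    compile_kernel(ValueKeyedVariant((128,), "float32"))
value_keyed_compiles = compile_count

for _ in range(3):
    compile_kernel(IdentityKeyedVariant((128,), "float32"))
identity_keyed_compiles = compile_count - value_keyed_compiles

print(value_keyed_compiles, identity_keyed_compiles)
```

Three calls with fresh value-keyed variants hit the cache after the first compile, while three fresh identity-keyed variants each trigger a new compile, which is the per-iteration recompilation the description reports.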
#3161
Summary
Notes
An experimental 64-wide MMA-M path was investigated separately, but it was reverted and is not part of this PR. The branch intentionally keeps the GPU-validated padded implementation.
Validation
- pre-commit run --files flashinfer/cute_dsl/attention/collective_builder.py flashinfer/cute_dsl/attention/mla_config.py flashinfer/cute_dsl/attention/mla_decode.py flashinfer/cute_dsl/attention/mla_decode_fp8.py flashinfer/cute_dsl/attention/scheduler/mla_persistent.py flashinfer/cute_dsl/attention/wrappers/batch_mla.py tests/attention/test_cute_dsl_mla_decode.py
- split_kv=32, output shape (1, 1, 64, 512), dtype torch.float16, no NaNs.
- 0.027255 ms, B=1/S=512 0.027000 ms, B=4/S=128 0.026925 ms, B=4/S=512 0.026979 ms.