Conversation
Important: Review skipped (draft detected).
📝 Walkthrough

The PR extends FMHA reduction kernel support for variable head dimensions, introduces an MLADimensions data structure for structured MLA configuration handling, expands kernel selection to accept both (576, 512) and (320, 256) head dimension pairs, and parametrizes tests to validate multiple MLA dimension configurations.
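The MLADimensions container and the two accepted head-dimension pairs mentioned in the walkthrough can be sketched as follows (field and property names here are assumptions for illustration, not the PR's actual definition):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MLADimensions:
    # Hypothetical field names; the PR's actual dataclass may differ.
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    kv_lora_rank: int

    @property
    def head_dim_qk(self) -> int:
        # The query/key head dim combines the KV LoRA rank and the RoPE part,
        # e.g. 512 + 64 = 576 for DeepSeek v3.
        return self.kv_lora_rank + self.qk_rope_head_dim

    @property
    def head_dim_v(self) -> int:
        # The value head dim equals the KV LoRA rank.
        return self.kv_lora_rank


# The two configurations the PR's kernel selection accepts:
deepseek = MLADimensions(qk_nope_head_dim=128, qk_rope_head_dim=64, kv_lora_rank=512)
smaller = MLADimensions(qk_nope_head_dim=64, qk_rope_head_dim=64, kv_lora_rank=256)
print((deepseek.head_dim_qk, deepseek.head_dim_v))  # (576, 512)
print((smaller.head_dim_qk, smaller.head_dim_v))    # (320, 256)
```

This reproduces the (576, 512) and (320, 256) pairs from the walkthrough under the stated field-name assumptions.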
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ❌ 3 failed checks (2 warnings, 1 inconclusive)
Summary of Changes

Hello @hypdeb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request enhances the flexibility and robustness of the FlashInfer library by extending support for Multi-head Latent Attention (MLA) kernels to additional head dimension configurations. It refactors the kernel selection mechanism to be more dynamic, abstracts MLA dimensions into a dedicated Python dataclass for better management, and updates the underlying compiled CUDA binaries. These changes let the library accommodate a wider range of model architectures and improve the maintainability and testability of the MLA implementation.
Code Review
This pull request introduces support for additional MLA dimensions, specifically for a smaller configuration (qk=320, v=256). The changes are propagated through the CUDA kernels, launchers, and Python wrappers. The test suite is also updated to cover these new dimensions.
My review has identified a few critical issues. In csrc/fmhaReduction.cu, the kernel selection logic is missing default cases, which could lead to a crash if unsupported dimensions are used. In include/flashinfer/trtllm/fmha/fmhaKernels.cuh, a debugging macro is used incorrectly, which would result in uninformative error messages. I've also pointed out some minor code cleanup in flashinfer/artifacts.py. Overall, the changes are in the right direction but the critical issues should be addressed before merging.
csrc/fmhaReduction.cu
Outdated
if (kernelMeta.mHeadDimV == 512) {
  if (headDimPerCtaV == 128) {
    SELECT_FMHA_REDUCTION_KERNEL(512, 128);
  } else if (headDimPerCtaV == 256) {
    SELECT_FMHA_REDUCTION_KERNEL(512, 256);
  } else if (headDimPerCtaV == 512) {
    SELECT_FMHA_REDUCTION_KERNEL(512, 512);
  }
} else if (kernelMeta.mHeadDimV == 256) {
  if (headDimPerCtaV == 128) {
    SELECT_FMHA_REDUCTION_KERNEL(256, 128);
  } else if (headDimPerCtaV == 256) {
    SELECT_FMHA_REDUCTION_KERNEL(256, 256);
  }
}
The kernel selection logic is missing default cases. If kernelMeta.mHeadDimV is not 512 or 256, or if headDimPerCtaV has an unexpected value for a given mHeadDimV, the kernel pointer will remain nullptr, leading to a crash when cudaLaunchKernelEx is called. It's safer to add else blocks with FLASHINFER_CHECK(false, ...) to handle unsupported dimension combinations explicitly.
if (kernelMeta.mHeadDimV == 512) {
if (headDimPerCtaV == 128) {
SELECT_FMHA_REDUCTION_KERNEL(512, 128);
} else if (headDimPerCtaV == 256) {
SELECT_FMHA_REDUCTION_KERNEL(512, 256);
} else if (headDimPerCtaV == 512) {
SELECT_FMHA_REDUCTION_KERNEL(512, 512);
} else {
FLASHINFER_CHECK(false, "Unsupported headDimPerCtaV for HeadDimV=512");
}
} else if (kernelMeta.mHeadDimV == 256) {
if (headDimPerCtaV == 128) {
SELECT_FMHA_REDUCTION_KERNEL(256, 128);
} else if (headDimPerCtaV == 256) {
SELECT_FMHA_REDUCTION_KERNEL(256, 256);
} else {
FLASHINFER_CHECK(false, "Unsupported headDimPerCtaV for HeadDimV=256");
}
} else {
FLASHINFER_CHECK(false, "Unsupported HeadDimV");
}
  FLASHINFER_CHECK(isFamilySpecificSMPair(existingKernelMeta.mSM, kernelMeta.mSM),
-                  "Hash conflicts exist between %s and %s.", existingKernelMeta.mFuncName,
-                  kernelMeta.mFuncName);
+                  "Hash conflicts exist between:\n  existing: ",
+                  kernelMetaToString(existingKernelMeta),
+                  "\n  new: ",
+                  kernelMetaToString(kernelMeta));
The usage of FLASHINFER_CHECK seems incorrect here. This macro appears to be printf-style, but the new code provides multiple string arguments without format specifiers. This will likely result in an uninformative error message, as the additional arguments will be ignored. The error message should be a single format string with arguments, and std::string objects should be converted to const char* using .c_str().
FLASHINFER_CHECK(isFamilySpecificSMPair(existingKernelMeta.mSM, kernelMeta.mSM),
"Hash conflicts exist between:\n existing: %s\n new: %s",
kernelMetaToString(existingKernelMeta).c_str(),
kernelMetaToString(kernelMeta).c_str());
flashinfer/artifacts.py

- TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/"
+ TRTLLM_GEN_FMHA: str = "e8a49d8aaab679fad8f52e696f8ba2bda01613c3/fmha/trtllm-gen/"
+ # TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/"
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tests/attention/test_trtllm_gen_mla.py (1)
316-316: ⚠️ Potential issue | 🟡 Minor

Hardcoded (128 + 64) scale denominator ignores attention_config for smaller_mla_dimensions

Both bmm1_scale (line 316) and sm_scale (lines 326–327) use the hardcoded DeepSeek v3 head dim (128 + 64) instead of attention_config.qk_nope_head_dim + attention_config.qk_rope_head_dim. For smaller_mla_dimensions (qk_nope=64, qk_rope=64) the denominator should be sqrt(128), not sqrt(192). Since both the kernel and the reference wrapper use the same value, the test comparison is still self-consistent, but the scale is semantically wrong and could mask actual numerical issues.

🐛 Proposed fix
- bmm1_scale=scale / ((128 + 64) ** 0.5),
+ bmm1_scale=scale / ((attention_config.qk_nope_head_dim + attention_config.qk_rope_head_dim) ** 0.5),

- sm_scale = scale / (
-     (128 + 64) ** 0.5
- )  # use head dimension before matrix absorption
+ sm_scale = scale / (
+     (attention_config.qk_nope_head_dim + attention_config.qk_rope_head_dim) ** 0.5
+ )  # use head dimension before matrix absorption

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In tests/attention/test_trtllm_gen_mla.py at line 316: replace the hardcoded DeepSeek v3 head-dim sum used in bmm1_scale and sm_scale with the actual head-dimension sum from the test's attention_config (attention_config.qk_nope_head_dim + attention_config.qk_rope_head_dim), so the denominator becomes the square root of that sum instead of sqrt(128 + 64). Update both occurrences that set bmm1_scale and sm_scale so they honor smaller_mla_dimensions.

csrc/fmhaReduction.cu (1)
311-312: ⚠️ Potential issue | 🟡 Minor

headDimPerCtaV==512 is permitted by the guard but unhandled in the mHeadDimV==256 dispatch branch — null kernel dereference if reached

The FLASHINFER_CHECK at line 311 accepts headDimPerCtaV ∈ {128, 256, 512} unconditionally, but the dispatch tree for mHeadDimV == 256 (lines 363–368) only covers 128 and 256. If headDimPerCtaV == 512 is somehow reached with mHeadDimV == 256, kernel stays nullptr and cuLaunchKernelEx at line 372 will crash. Adding a null guard and a tighter check will eliminate the risk:

🛡️ Proposed defensive fix
FLASHINFER_CHECK(headDimPerCtaV == 128 || headDimPerCtaV == 256 || headDimPerCtaV == 512,
                 "Not implemented");

And after the dispatch block:

+ FLASHINFER_CHECK(kernel != nullptr,
+                  "Unsupported (mHeadDimV=%d, headDimPerCtaV=%d) combination in fmhaReduction",
+                  kernelMeta.mHeadDimV, headDimPerCtaV);
  // Launch the kernel.
  cudaLaunchKernelEx(&config, kernel, params, ...);

Also applies to: 363-372
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In csrc/fmhaReduction.cu around lines 311-312: the guard allows headDimPerCtaV==512, but the mHeadDimV==256 dispatch branch only sets kernel for 128 and 256, risking a null kernel and a crash at cuLaunchKernelEx. Update the dispatch in the mHeadDimV==256 case (the branch that inspects headDimPerCtaV and assigns kernel) to either handle the 512 case properly or tighten the initial FLASHINFER_CHECK to forbid 512 when mHeadDimV==256, and add a defensive null check before cuLaunchKernelEx (verify kernel != nullptr and return/log an error) to prevent a dereference if an unsupported headDimPerCtaV slips through.

flashinfer/mla.py (2)
562-566: ⚠️ Potential issue | 🟡 Minor

Stale docstring — dimension constraints no longer accurate

The docstring claims qk_nope_head_dim "must be 128", kv_lora_rank "must be 512", and qk_rope_head_dim "must be 64", but the function now supports both deepseek_mla_dimensions and smaller_mla_dimensions (64/256/64). Update it to reference supported_mla_dimensions or list both accepted configurations.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In flashinfer/mla.py around lines 562-566: update the stale docstring that currently asserts fixed sizes for qk_nope_head_dim, kv_lora_rank, and qk_rope_head_dim. Replace the hard "must be X" text with a statement that the function accepts the predefined configurations (e.g., deepseek_mla_dimensions and smaller_mla_dimensions), or a single reference to supported_mla_dimensions; mention both accepted tuples (64/256/64 and 128/512/64) or point readers to supported_mla_dimensions so the docstring no longer claims strict single values.
785-794: ⚠️ Potential issue | 🔴 Critical

_check_trtllm_gen_mla_shape called with the old 8-arg signature — breaks CI

xqa_batch_decode_with_kv_cache_mla was not updated when _check_trtllm_gen_mla_shape was refactored from (query, kv_cache, qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim, sparse_mla_top_k, page_table, page_size) to (query, kv_cache, mla_dimensions: MLADimensions, sparse_mla_top_k, page_table, page_size). The CI pipeline confirms this with "Too many arguments" and "Argument 3 has incompatible type int; expected MLADimensions".

🐛 Proposed fix
- kv_cache = _check_trtllm_gen_mla_shape(
-     query,
-     kv_cache,
-     qk_nope_head_dim,
-     kv_lora_rank,
-     qk_rope_head_dim,
-     0,  # sparse_mla_top_k
-     block_tables,
-     block_size,
- )
+ kv_cache = _check_trtllm_gen_mla_shape(
+     query,
+     kv_cache,
+     MLADimensions(
+         qk_nope_head_dim=qk_nope_head_dim,
+         qk_rope_head_dim=qk_rope_head_dim,
+         kv_lora_rank=kv_lora_rank,
+     ),
+     0,  # sparse_mla_top_k
+     block_tables,
+     block_size,
+ )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In flashinfer/mla.py around lines 785-794: the call to _check_trtllm_gen_mla_shape in xqa_batch_decode_with_kv_cache_mla uses the old 8-arg signature. Construct an MLADimensions instance from qk_nope_head_dim, kv_lora_rank, and qk_rope_head_dim, and call _check_trtllm_gen_mla_shape(query, kv_cache, mla_dimensions, 0 /* sparse_mla_top_k */, block_tables, block_size) instead of passing the three ints; update the call site to the new 6-arg signature and remove the old 8-argument ordering.
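The migration pattern the comment describes can be sketched in isolation (the function name, dataclass fields, and validation body below are assumptions based on the review, not the repository's actual implementation):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MLADimensions:
    # Hypothetical field names mirroring the review comment.
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    kv_lora_rank: int


def check_mla_shape(query_head_dim: int, dims: MLADimensions,
                    sparse_mla_top_k: int = 0) -> None:
    # New-style signature: one dataclass instead of three bare ints.
    expected = dims.kv_lora_rank + dims.qk_rope_head_dim
    if query_head_dim != expected:
        raise ValueError(f"query head dim {query_head_dim} != expected {expected}")


# Old call sites passed qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim
# positionally; the fix wraps them into one MLADimensions first:
dims = MLADimensions(qk_nope_head_dim=128, qk_rope_head_dim=64, kv_lora_rank=512)
check_mla_shape(576, dims)   # OK for the DeepSeek v3 configuration
try:
    check_mla_shape(320, dims)  # raises ValueError: wrong config for these dims
except ValueError as e:
    print(e)
```

Bundling the three ints into one value object is what makes the 8-arg to 6-arg refactor safe: a stale call site now fails type checking instead of silently passing ints in the wrong order.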
🧹 Nitpick comments (1)
flashinfer/mla.py (1)
554-554: bool | None mixes union-syntax styles with the rest of the file

enable_pdl: bool | None = None (Python 3.10+ PEP 604 syntax) is inconsistent with the rest of mla.py, which imports and uses Optional from typing. Use Optional[bool] for consistency:

- enable_pdl: bool | None = None,
+ enable_pdl: Optional[bool] = None,

Same applies to xqa_batch_decode_with_kv_cache_mla at line 734.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In flashinfer/mla.py at line 554: the parameter annotation uses PEP 604 union syntax ("bool | None"), which is inconsistent with the rest of the file. Change the enable_pdl annotation, and the corresponding parameter of xqa_batch_decode_with_kv_cache_mla, to Optional[bool] (i.e., "enable_pdl: Optional[bool] = None"), and ensure Optional is imported from typing if it is not already.
flashinfer/artifacts.py
Outdated
TRTLLM_GEN_FMHA: str = "e8a49d8aaab679fad8f52e696f8ba2bda01613c3/fmha/trtllm-gen/"
# TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/"
CheckSumHash.TRTLLM_GEN_FMHA not updated alongside ArtifactPath.TRTLLM_GEN_FMHA
The class docstring at line 107 explicitly says: "When updating the ArtifactPath for backend directories, update the corresponding hash." CheckSumHash.TRTLLM_GEN_FMHA (line 110) still holds the sha256 of the old checksums.txt (e86f0e45… path); it needs to be updated to the sha256 of the checksums.txt at the new e8a49d8a… path, otherwise get_subdir_file_list → verify_cubin will fail with a checksum mismatch.
If this is intentionally deferred until the new cubins are published (per the draft note), please add a # TODO: update after cubin publish comment here so the obligation is visible.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/artifacts.py` around lines 90 - 91, Update the checksum constant
to match the new ArtifactPath: compute the sha256 of the checksums.txt at the
new ArtifactPath.TRTLLM_GEN_FMHA (currently
"e8a49d8aaab679fad8f52e696f8ba2bda01613c3/fmha/trtllm-gen/") and replace
CheckSumHash.TRTLLM_GEN_FMHA's old value (e86f0e45…) with that new hash so
get_subdir_file_list → verify_cubin will validate correctly; if you
intentionally cannot update the hash yet, add a clear TODO comment on
CheckSumHash.TRTLLM_GEN_FMHA like “# TODO: update after cubin publish” to make
the deferred obligation explicit.
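Computing the replacement hash for the new checksums.txt takes only a few lines of Python (the streaming helper below is a generic sketch, not FlashInfer's own tooling, and the file path would be the downloaded checksums.txt):

```python
import hashlib
import os
import tempfile


def sha256_of_file(path: str) -> str:
    # Stream the file so large checksum lists don't need to fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()


# Example on an empty temp file: the well-known SHA-256 of zero bytes.
fd, tmp = tempfile.mkstemp()
os.close(fd)
print(sha256_of_file(tmp))
# e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
os.remove(tmp)
```

The resulting hex digest is what would replace the stale value in CheckSumHash.TRTLLM_GEN_FMHA once the new cubins are published.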
Force-pushed from de9e12c to b89b4b4 (Compare)
Force-pushed from ec0bd61 to 3299a25 (Compare)
Additional MLA dimensions
Draft, pending merge of CUBIN publishing changes to main.