
Additional kernels #2582

Closed

jdebache wants to merge 1 commit into flashinfer-ai:main from jdebache:more_mla_dims

Conversation

@jdebache
Contributor

@jdebache jdebache commented Feb 18, 2026

Additional MLA dimensions

Draft, pending merge of CUBIN publishing changes to main.

Summary by CodeRabbit

  • New Features

    • Expanded support for additional head dimension configurations in multi-head latent attention (MLA) models
    • Introduced structured MLA dimension configuration with predefined presets to streamline model setup
  • Improvements

    • Enhanced kernel dispatch and selection logic to support broader model architecture variations
    • Improved kernel metadata debugging capabilities

@coderabbitai
Contributor

coderabbitai bot commented Feb 18, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

📝 Walkthrough

Walkthrough

The PR extends FMHA reduction kernel support for variable head dimensions, introduces an MLADimensions data structure for structured MLA configuration handling, expands kernel selection to accept both (576, 512) and (320, 256) head dimension pairs, and parametrizes tests to validate multiple MLA dimension configurations.

Changes

  • Kernel Dispatch and Selection (csrc/fmhaReduction.cu, include/flashinfer/trtllm/fmha/fmhaKernels.cuh): Updated the FMHA reduction kernel selector to accept two parameters (HeadDim, HeadDimPerCta) instead of one; expanded kernel selection to support both (576, 512) and (320, 256) head dimension pairs; added verbose debug logging for kernel conflicts and metadata tracking; adjusted headDimPerCtaV handling for 2-CTA MlaGen kernels.
  • Runtime Validation (csrc/trtllm_fmha_kernel_launcher.cu): Broadened MLA configuration validation to accept (320, 256) head dimensions in addition to (576, 512) for both encode and decode paths.
  • MLA Dimension Configuration (flashinfer/mla.py): Introduced the MLADimensions immutable data structure and predefined presets (deepseek_mla_dimensions, smaller_mla_dimensions); replaced multiple discrete dimension parameters with a single mla_dimensions parameter; enhanced validation to handle both 3D and 4D kv_cache inputs.
  • Artifact Path (flashinfer/artifacts.py): Updated the TRTLLM_GEN_FMHA artifact hash constant, with a comment preserving the previous value.
  • Test Parametrization (tests/attention/test_trtllm_gen_mla.py): Extended tests to dynamically parametrize with MLADimensions configurations; replaced hard-coded MLA dimension constants with configuration-driven values; added test coverage for multiple MLA dimension presets.
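As a rough sketch of the structure described above (field names and preset values are inferred from the review comments in this thread, not copied from the actual diff), MLADimensions and its presets might look like:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MLADimensions:
    """Hypothetical sketch of the immutable MLA dimension configuration."""

    qk_nope_head_dim: int
    qk_rope_head_dim: int
    kv_lora_rank: int

    @property
    def qk_head_dim(self) -> int:
        # After matrix absorption, queries/keys carry kv_lora_rank + rope dims.
        return self.kv_lora_rank + self.qk_rope_head_dim


# Presets named in the PR; values inferred from the (576, 512) / (320, 256)
# head dimension pairs discussed below.
deepseek_mla_dimensions = MLADimensions(128, 64, 512)  # qk = 576, v = 512
smaller_mla_dimensions = MLADimensions(64, 64, 256)    # qk = 320, v = 256
```

A frozen dataclass keeps the presets hashable and immutable, so they can safely be shared as module-level constants and used as pytest parameters.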

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

run-ci

Suggested reviewers

  • PerkzZheng
  • bkryu
  • nvmbreughe
  • aleozlx
  • cyx-6
  • djmmoss
  • wenscarl
  • yzh119

Poem

🐰 Two dimensions dance with grace,
576 meets 320's face,
Kernels dispatch with structured care,
MLADimensions float through air,
Tests bloom bright in parametric bloom! ✨

🚥 Pre-merge checks | ❌ 3

❌ Failed checks (2 warnings, 1 inconclusive)

  • Description check — ⚠️ Warning. The PR description is minimal (one line) and does not follow the provided template structure; it lacks detailed explanation of changes, rationale, or testing information. Resolution: complete the description template by adding a detailed explanation of what changes are made and why they're needed, and confirm that tests have been added and pre-commit checks are passing.
  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 36.36%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check — ❓ Inconclusive. The title "Additional kernels" is vague and does not clearly convey the primary change; it refers to MLA dimension support but lacks specificity. Resolution: consider a more specific title like "Support additional MLA dimensions" or "Add MLA kernel variants for smaller dimensions" that better describes the actual changes.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@gemini-code-assist
Contributor

Summary of Changes

Hello @hypdeb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and robustness of the FlashInfer library by extending support for Multi-head Latent Attention (MLA) kernels to include additional head dimension configurations. It refactors the kernel selection mechanism to be more dynamic, abstracts MLA dimensions into a dedicated Python dataclass for better management, and updates the underlying compiled CUDA binaries. These changes enable the library to accommodate a wider range of model architectures and improve the overall maintainability and testability of the MLA implementation.

Highlights

  • Expanded MLA Kernel Support: Added support for new Multi-head Latent Attention (MLA) dimensions, specifically a (320, 256) head dimension pair, alongside the existing (576, 512) configuration, allowing for greater flexibility in model architectures.
  • Dynamic Kernel Selection: Refactored the SELECT_FMHA_REDUCTION_KERNEL macro and its dispatch logic to dynamically select kernels based on the actual HeadDim (mHeadDimV) rather than a hardcoded value, improving adaptability to various head dimensions.
  • MLA Dimension Abstraction: Introduced an MLADimensions dataclass in Python to encapsulate and manage different MLA configurations, centralizing dimension definitions and improving type safety and readability in the API.
  • Updated CUBIN Artifacts: Updated the TRTLLM_GEN_FMHA CUBIN path to incorporate newly compiled kernels that support the expanded MLA dimensions.
  • Enhanced Debugging and Testing: Improved kernel metadata logging with a kernelMetaToString helper for better debugging of hash conflicts and extended unit tests to cover the newly supported MLA dimensions.


Changelog
  • csrc/fmhaReduction.cu
    • Modified SELECT_FMHA_REDUCTION_KERNEL macro to accept HeadDim as a parameter, replacing a hardcoded value.
    • Expanded FLASHINFER_CHECK to include (320, 256) as a valid mHeadDimQk and mHeadDimV pair.
    • Updated kernel selection logic in runFmhaReduction to dispatch based on kernelMeta.mHeadDimV.
  • csrc/trtllm_fmha_kernel_launcher.cu
    • Extended TVM_FFI_ICHECK for sparse MLA to recognize (320, 256) head dimensions.
    • Added (320, 256) as a valid head dimension pair for non-MLA attention checks.
  • flashinfer/artifacts.py
    • Updated the SHA for TRTLLM_GEN_FMHA to point to new CUBIN artifacts.
  • flashinfer/mla.py
    • Imported dataclass for defining structured data.
    • Defined MLADimensions dataclass and added deepseek_mla_dimensions, smaller_mla_dimensions, and supported_mla_dimensions constants.
    • Refactored _check_trtllm_gen_mla_shape to accept an MLADimensions object for validation.
    • Updated type hint for enable_pdl parameter in trtllm_batch_decode_with_kv_cache_mla to bool | None.
    • Modified trtllm_batch_decode_with_kv_cache_mla to construct and pass an MLADimensions object.
    • Updated type hint for enable_pdl parameter in xqa_batch_decode_with_kv_cache_mla to bool | None.
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
    • Included <sstream> for string manipulation.
    • Added kernelMetaToString helper function for debugging TllmGenFmhaKernelMetaInfo.
    • Enhanced logging and error messages in TllmGenFmhaKernel using kernelMetaToString.
    • Updated isMlaGenKernel to check for (320, 256) head dimensions.
    • Changed mHeadDimPerCtaV calculation to be dynamic based on params.mHeadDimV.
  • tests/attention/test_trtllm_gen_mla.py
    • Imported MLADimensions, deepseek_mla_dimensions, and smaller_mla_dimensions.
    • Updated trtllm_batch_decode_mla function signature to accept attention_config and num_q_heads.
    • Replaced hardcoded attention config parameters with the attention_config object.
    • Adjusted kv_cache size calculation to use attention_config.
    • Modified calls to flashinfer.dispatch.trtllm_batch_decode_with_kv_cache_mla to pass attention_config fields.
    • Updated reference calculation for q_nope, q_pe, ckv, and kpe to use attention_config.
    • Added pytest parameterization for mla_dimensions to test different configurations.
    • Updated test function calls to pass the mla_dimensions and num_q_heads.
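The parametrization described in the changelog above could be sketched roughly as follows (preset contents and test name are hypothetical stand-ins, assuming the (576, 512) / (320, 256) pairs from this PR):

```python
import pytest

# Hypothetical stand-ins for deepseek_mla_dimensions / smaller_mla_dimensions.
MLA_DIMENSION_PRESETS = [
    {"name": "deepseek", "qk_nope_head_dim": 128, "qk_rope_head_dim": 64, "kv_lora_rank": 512},
    {"name": "smaller", "qk_nope_head_dim": 64, "qk_rope_head_dim": 64, "kv_lora_rank": 256},
]


@pytest.mark.parametrize("attention_config", MLA_DIMENSION_PRESETS, ids=lambda c: c["name"])
def test_mla_decode_shapes(attention_config):
    # The kernel sees (kv_lora_rank + qk_rope_head_dim, kv_lora_rank) head dims.
    qk_head_dim = attention_config["kv_lora_rank"] + attention_config["qk_rope_head_dim"]
    assert (qk_head_dim, attention_config["kv_lora_rank"]) in [(576, 512), (320, 256)]
```

Driving the test from a preset list keeps the dimension constants out of the test body, which is the refactor the changelog describes.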
Activity
  • The pull request is currently a draft, awaiting the merge of CUBIN publishing changes to the main branch.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for additional MLA dimensions, specifically for a smaller configuration (qk=320, v=256). The changes are propagated through the CUDA kernels, launchers, and Python wrappers. The test suite is also updated to cover these new dimensions.

My review has identified a few critical issues. In csrc/fmhaReduction.cu, the kernel selection logic is missing default cases, which could lead to a crash if unsupported dimensions are used. In include/flashinfer/trtllm/fmha/fmhaKernels.cuh, a debugging macro is used incorrectly, which would result in uninformative error messages. I've also pointed out some minor code cleanup in flashinfer/artifacts.py. Overall, the changes are in the right direction but the critical issues should be addressed before merging.

Comment on lines 355 to 369
  if (kernelMeta.mHeadDimV == 512) {
    if (headDimPerCtaV == 128) {
      SELECT_FMHA_REDUCTION_KERNEL(512, 128);
    } else if (headDimPerCtaV == 256) {
      SELECT_FMHA_REDUCTION_KERNEL(512, 256);
    } else if (headDimPerCtaV == 512) {
      SELECT_FMHA_REDUCTION_KERNEL(512, 512);
    }
  } else if (kernelMeta.mHeadDimV == 256) {
    if (headDimPerCtaV == 128) {
      SELECT_FMHA_REDUCTION_KERNEL(256, 128);
    } else if (headDimPerCtaV == 256) {
      SELECT_FMHA_REDUCTION_KERNEL(256, 256);
    }
  }

critical

The kernel selection logic is missing default cases. If kernelMeta.mHeadDimV is not 512 or 256, or if headDimPerCtaV has an unexpected value for a given mHeadDimV, the kernel pointer will remain nullptr, leading to a crash when cudaLaunchKernelEx is called. It's safer to add else blocks with FLASHINFER_CHECK(false, ...) to handle unsupported dimension combinations explicitly.

  if (kernelMeta.mHeadDimV == 512) {
    if (headDimPerCtaV == 128) {
      SELECT_FMHA_REDUCTION_KERNEL(512, 128);
    } else if (headDimPerCtaV == 256) {
      SELECT_FMHA_REDUCTION_KERNEL(512, 256);
    } else if (headDimPerCtaV == 512) {
      SELECT_FMHA_REDUCTION_KERNEL(512, 512);
    } else {
      FLASHINFER_CHECK(false, "Unsupported headDimPerCtaV for HeadDimV=512");
    }
  } else if (kernelMeta.mHeadDimV == 256) {
    if (headDimPerCtaV == 128) {
      SELECT_FMHA_REDUCTION_KERNEL(256, 128);
    } else if (headDimPerCtaV == 256) {
      SELECT_FMHA_REDUCTION_KERNEL(256, 256);
    } else {
      FLASHINFER_CHECK(false, "Unsupported headDimPerCtaV for HeadDimV=256");
    }
  } else {
    FLASHINFER_CHECK(false, "Unsupported HeadDimV");
  }

Comment on lines +165 to +169
     FLASHINFER_CHECK(isFamilySpecificSMPair(existingKernelMeta.mSM, kernelMeta.mSM),
  -                   "Hash conflicts exist between %s and %s.", existingKernelMeta.mFuncName,
  -                   kernelMeta.mFuncName);
  +                   "Hash conflicts exist between:\n  existing: ",
  +                   kernelMetaToString(existingKernelMeta),
  +                   "\n  new: ",
  +                   kernelMetaToString(kernelMeta));

critical

The usage of FLASHINFER_CHECK seems incorrect here. This macro appears to be printf-style, but the new code provides multiple string arguments without format specifiers. This will likely result in an uninformative error message, as the additional arguments will be ignored. The error message should be a single format string with arguments, and std::string objects should be converted to const char* using .c_str().

          FLASHINFER_CHECK(isFamilySpecificSMPair(existingKernelMeta.mSM, kernelMeta.mSM),
                           "Hash conflicts exist between:\n  existing: %s\n  new:      %s",
                           kernelMetaToString(existingKernelMeta).c_str(),
                           kernelMetaToString(kernelMeta).c_str());


  -    TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/"
  +    TRTLLM_GEN_FMHA: str = "e8a49d8aaab679fad8f52e696f8ba2bda01613c3/fmha/trtllm-gen/"
  +    # TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/"

medium

Please remove the commented-out old TRTLLM_GEN_FMHA path before merging to keep the code clean.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tests/attention/test_trtllm_gen_mla.py (1)

316-316: ⚠️ Potential issue | 🟡 Minor

Hardcoded (128 + 64) scale denominator ignores attention_config for smaller_mla_dimensions

Both bmm1_scale (line 316) and sm_scale (lines 326–327) use the hardcoded DeepSeek v3 head dim (128 + 64) instead of attention_config.qk_nope_head_dim + attention_config.qk_rope_head_dim. For smaller_mla_dimensions (qk_nope=64, qk_rope=64) the denominator should be sqrt(128) not sqrt(192). Since both the kernel and the reference wrapper use the same value, the test comparison is still self-consistent, but the scale is semantically wrong and could mask actual numerical issues.

🐛 Proposed fix
-        bmm1_scale=scale / ((128 + 64) ** 0.5),
+        bmm1_scale=scale / ((attention_config.qk_nope_head_dim + attention_config.qk_rope_head_dim) ** 0.5),
-    sm_scale = scale / (
-        (128 + 64) ** 0.5
-    )  # use head dimension before matrix absorption
+    sm_scale = scale / (
+        (attention_config.qk_nope_head_dim + attention_config.qk_rope_head_dim) ** 0.5
+    )  # use head dimension before matrix absorption
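The numerical impact of the wrong denominator is easy to see; this is a pure arithmetic illustration of the point above:

```python
import math

# DeepSeek preset: pre-absorption head dim is 128 + 64 = 192.
deepseek_denom = math.sqrt(128 + 64)

# Smaller preset: pre-absorption head dim is 64 + 64 = 128.
smaller_denom = math.sqrt(64 + 64)

# Reusing the hardcoded sqrt(192) for the smaller preset inflates the
# denominator (and thus shrinks the softmax scale) by sqrt(1.5), about 22%.
error_ratio = deepseek_denom / smaller_denom
```

Because both the kernel and the reference use the same wrong scale, the test still passes, which is exactly why this kind of error can hide real numerical issues.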
csrc/fmhaReduction.cu (1)

311-312: ⚠️ Potential issue | 🟡 Minor

headDimPerCtaV==512 is permitted by the guard but unhandled in the mHeadDimV==256 dispatch branch — null kernel dereference if reached

The FLASHINFER_CHECK at line 311 accepts headDimPerCtaV ∈ {128, 256, 512} unconditionally, but the dispatch tree for mHeadDimV == 256 (lines 363–368) only covers 128 and 256. If headDimPerCtaV == 512 is somehow reached with mHeadDimV == 256, kernel stays nullptr and cuLaunchKernelEx at line 372 will crash. Adding a null guard and a tighter check will eliminate the risk:

🛡️ Proposed defensive fix
   FLASHINFER_CHECK(headDimPerCtaV == 128 || headDimPerCtaV == 256 || headDimPerCtaV == 512,
                    "Not implemented");

And after the dispatch block:

+  FLASHINFER_CHECK(kernel != nullptr,
+                   "Unsupported (mHeadDimV=%d, headDimPerCtaV=%d) combination in fmhaReduction",
+                   kernelMeta.mHeadDimV, headDimPerCtaV);

   // Launch the kernel.
   cudaLaunchKernelEx(&config, kernel, params, ...);

Also applies to: 363-372

flashinfer/mla.py (2)

562-566: ⚠️ Potential issue | 🟡 Minor

Stale docstring — dimension constraints no longer accurate

The docstring claims qk_nope_head_dim "must be 128", kv_lora_rank "must be 512", and qk_rope_head_dim "must be 64", but the function now supports both deepseek_mla_dimensions and smaller_mla_dimensions (64/256/64). Update to reference supported_mla_dimensions or list both accepted configurations.


785-794: ⚠️ Potential issue | 🔴 Critical

_check_trtllm_gen_mla_shape called with the old 8-arg signature — breaks CI

xqa_batch_decode_with_kv_cache_mla was not updated when _check_trtllm_gen_mla_shape was refactored from (query, kv_cache, qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim, sparse_mla_top_k, page_table, page_size) to (query, kv_cache, mla_dimensions: MLADimensions, sparse_mla_top_k, page_table, page_size). The CI pipeline confirms this with "Too many arguments" and "Argument 3 has incompatible type int; expected MLADimensions".

🐛 Proposed fix
-    kv_cache = _check_trtllm_gen_mla_shape(
-        query,
-        kv_cache,
-        qk_nope_head_dim,
-        kv_lora_rank,
-        qk_rope_head_dim,
-        0,  # sparse_mla_top_k
-        block_tables,
-        block_size,
-    )
+    kv_cache = _check_trtllm_gen_mla_shape(
+        query,
+        kv_cache,
+        MLADimensions(
+            qk_nope_head_dim=qk_nope_head_dim,
+            qk_rope_head_dim=qk_rope_head_dim,
+            kv_lora_rank=kv_lora_rank,
+        ),
+        0,  # sparse_mla_top_k
+        block_tables,
+        block_size,
+    )
🧹 Nitpick comments (1)
flashinfer/mla.py (1)

554-554: bool | None mixes union-syntax styles with the rest of the file

enable_pdl: bool | None = None (Python 3.10+ PEP 604 syntax) is inconsistent with the rest of mla.py which imports and uses Optional from typing. Use Optional[bool] for consistency.

-    enable_pdl: bool | None = None,
+    enable_pdl: Optional[bool] = None,

Same applies to xqa_batch_decode_with_kv_cache_mla at line 734.


Nitpick comments:
In `@flashinfer/mla.py`:
- Line 554: The parameter type annotations use PEP 604 union syntax ("bool |
None") which is inconsistent with the rest of the file; change the annotations
for enable_pdl in the function signature that declares "enable_pdl: bool | None
= None" and for xqa_batch_decode_with_kv_cache_mla's corresponding parameter to
use Optional[bool] instead (i.e., "enable_pdl: Optional[bool] = None"), and
ensure Optional is imported from typing if not already used in the file.

Comment on lines +90 to +91
  +    TRTLLM_GEN_FMHA: str = "e8a49d8aaab679fad8f52e696f8ba2bda01613c3/fmha/trtllm-gen/"
  +    # TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/"

⚠️ Potential issue | 🟠 Major

CheckSumHash.TRTLLM_GEN_FMHA not updated alongside ArtifactPath.TRTLLM_GEN_FMHA

The class docstring at line 107 explicitly says: "When updating the ArtifactPath for backend directories, update the corresponding hash." CheckSumHash.TRTLLM_GEN_FMHA (line 110) still holds the sha256 of the old checksums.txt (e86f0e45… path); it needs to be updated to the sha256 of the checksums.txt at the new e8a49d8a… path, otherwise get_subdir_file_list → verify_cubin will fail with a checksum mismatch.

If this is intentionally deferred until the new cubins are published (per the draft note), please add a # TODO: update after cubin publish comment here so the obligation is visible.
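Computing the replacement hash is straightforward once the new checksums.txt is available locally; a minimal sketch (the helper name is illustrative, not part of the repository):

```python
import hashlib


def sha256_of_file(path: str) -> str:
    """Return the hex sha256 digest of a file, e.g. a downloaded checksums.txt."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in chunks so large artifact listings don't need to fit in memory.
        for chunk in iter(lambda: f.read(1 << 16), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

The resulting hex string would replace the stale value in CheckSumHash.TRTLLM_GEN_FMHA.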

