
Support BF16 MLA on SM120 with shared-mem fallback#2675

Open
maomao123321 wants to merge 1 commit into flashinfer-ai:main from maomao123321:feat-sm120-bf16-mla

Conversation

@maomao123321 maomao123321 commented Mar 3, 2026

📌 Description

This PR adds BF16 support to the XQA MLA decode kernel on SM120/SM121 while respecting the 99 KB per-block shared memory limit on consumer Blackwell GPUs. It reuses the existing MLA pipeline and MMA infrastructure, tunes the BF16 tiling and buffer configuration, and adds runtime/device capability checks so that BF16 MLA is enabled on GPUs with sufficient shared memory and falls back cleanly to FA2 on devices that cannot support the full MLA configuration.
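The capability-gated fallback described above can be sketched as a small dispatch function. This is purely illustrative: `select_mla_backend` and `SMEM_NEEDED_BF16` are hypothetical names, and the real checks live in `csrc/xqa/mla_sm120.cu` and `flashinfer/mla.py`. Note that the 99 KB limit quoted in this thread is exactly 101376 bytes.

```python
# Illustrative sketch, not the PR's actual code: the backend decision the
# description outlines. SMEM_NEEDED_BF16 reflects the 99 KB (= 101376 bytes)
# opt-in figure discussed in this PR; the function name is hypothetical.

SMEM_NEEDED_BF16 = 99 * 1024  # 101376 bytes, the opt-in maximum on SM120/SM121

def select_mla_backend(dtype: str, dev_max_optin_smem: int) -> str:
    """Pick a decode backend for one device, per the PR description."""
    if dtype == "fp8_e4m3":
        return "xqa_mla"  # existing FP8 config fits comfortably
    if dtype == "bfloat16":
        if dev_max_optin_smem >= SMEM_NEEDED_BF16:
            return "xqa_mla"  # BF16 fits with opt-in shared memory
        return "fa2"  # clean fallback on devices without enough shared memory
    raise ValueError(f"unsupported dtype: {dtype}")
```

The point of the sketch is that the dtype alone is not enough: the same BF16 request routes differently depending on the device's opt-in shared-memory ceiling.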

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • [x] I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • [x] I have installed the hooks with pre-commit install.
  • [x] I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • [x] Tests have been added or updated as needed.
  • [x] All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Extended MLA batch decode operations to support BF16 precision type alongside existing FP8, expanding compatibility options
    • Added device-side kernel memory validation to enforce shared memory usage constraints against device limits
  • Chores

    • Refined data type validation with enhanced error messaging documenting supported type pair combinations

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the XQA MLA decode kernel by introducing Bfloat16 (BF16) support for SM120/SM121 GPUs. The changes involve carefully tuning shared memory usage and buffer configurations to accommodate BF16 while adhering to hardware limitations. It also includes robust runtime checks to ensure compatibility and graceful fallback mechanisms, making the system more versatile for different precision requirements.

Highlights

  • BF16 Support for XQA MLA: Added support for BF16 (Bfloat16) data type in the XQA MLA decode kernel, specifically targeting SM120/SM121 GPUs.
  • Dynamic Shared Memory Configuration: Implemented dynamic adjustments to tiling and buffer configurations (e.g., partElemsK, tokensPerTile, nbKBufs) based on whether the operation uses FP8 or BF16, optimizing for shared memory usage.
  • Shared Memory Limit Adherence: Ensured that BF16 configurations respect the 99 KB per-block shared memory limit on consumer Blackwell GPUs, with an opt-in allowance for up to 101376 bytes where supported.
  • Runtime Device Capability Checks: Introduced runtime and device capability checks to enable BF16 MLA on GPUs with sufficient shared memory, providing a fallback to FA2 on devices that cannot support the full MLA configuration.
  • MMA Infrastructure Integration: Integrated BF16 support by reusing the existing MLA pipeline and MMA (Matrix Multiply-Accumulate) infrastructure, adapting it for the new data type.
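A quick back-of-envelope check on the shared-memory tuning summarized above. The figures come from this PR thread (partElemsK = 64 for both types; tokensPerTile of 64 for FP8 versus 32 for BF16); the single-buffer formula below is a simplification for illustration, not the kernel's actual layout code.

```python
# Illustrative arithmetic: why halving tokensPerTile for BF16 keeps one
# K staging buffer the same size despite the 2-byte element type. The
# helper name and the "one buffer = tokens x partition-K elements" model
# are assumptions made for this sketch.

def k_buf_bytes(part_elems_k: int, tokens_per_tile: int, elem_bytes: int) -> int:
    """Bytes for one K staging buffer: tokens x partition-K elements."""
    return part_elems_k * tokens_per_tile * elem_bytes

fp8_buf = k_buf_bytes(64, 64, 1)   # FP8: 1-byte elements, 64 tokens per tile
bf16_buf = k_buf_bytes(64, 32, 2)  # BF16: 2-byte elements, halved tile
assert fp8_buf == bf16_buf == 4096
```

With per-buffer size held constant, the remaining BF16 savings must come from buffer counts, which is why the diff also reduces nbKBufs for BF16.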


Changelog
  • csrc/xqa/mla_sm120.cu
    • Introduced MathElem, is_fp8, and is_bf16 type aliases and constants to enable data type-dependent logic.
    • Adjusted partElemsK and tokensPerTile to be conditional on is_bf16 for optimized tiling.
    • Defined kernelQmmaShape to dynamically select MMA shapes based on is_fp8.
    • Updated Mat16x32Loader to use the new kernelQmmaShape for matrix loading.
    • Modified SharedMemA::nbKBufs to reduce K-buffers for BF16 to fit within shared memory limits.
    • Replaced hardcoded __nv_fp8_e4m3 with MathElem in mma operations within the Producer and Consumer structs.
    • Introduced a new storeOrderedXToShmBf16 function for storing BF16 data to shared memory.
    • Updated the smemSize static assertion to allow a larger shared memory limit (101376 bytes) for BF16.
    • Added runtime checks for cudaDevAttrMaxSharedMemoryPerBlockOptin and set cudaFuncAttributePreferredSharedMemoryCarveout to manage shared memory allocation for the kernel.
  • flashinfer/mla.py
    • Modified trtllm_batch_decode_with_kv_cache_mla to explicitly allow torch.bfloat16 alongside torch.float8_e4m3fn for query and KV cache dtypes.
    • Updated xqa_batch_decode_with_kv_cache_mla to support torch.bfloat16 for query and KV cache dtypes, in addition to torch.float8_e4m3fn.
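The dtype-pair rule those two Python entry points now enforce can be sketched as follows. Dtype-name strings stand in for the real torch dtype objects so the example runs without torch; the actual code compares torch.float8_e4m3fn and torch.bfloat16 directly.

```python
# Sketch of the dtype-pair validation added in flashinfer/mla.py: only
# matched (fp8, fp8) or (bf16, bf16) pairs are accepted. The helper name
# validate_xqa_dtypes is hypothetical; the PR performs this check inline.

ALLOWED_PAIRS = {
    ("float8_e4m3fn", "float8_e4m3fn"),
    ("bfloat16", "bfloat16"),
}

def validate_xqa_dtypes(q_dtype: str, kv_dtype: str) -> None:
    """Reject any query/KV-cache dtype combination outside ALLOWED_PAIRS."""
    if (q_dtype, kv_dtype) not in ALLOWED_PAIRS:
        raise ValueError(
            f"XQA MLA on SM120/SM121 supports (fp8, fp8) or "
            f"(bfloat16, bfloat16) only, got {q_dtype} and {kv_dtype}"
        )
```

Mixed pairs such as a BF16 query against an FP8 KV cache are rejected up front rather than failing inside the kernel.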
Activity
  • No specific activity (comments, reviews, progress updates) has been recorded for this pull request yet.

@coderabbitai
Contributor

coderabbitai bot commented Mar 3, 2026

📝 Walkthrough

Walkthrough

This PR extends MLA (Multi-Head Latent Attention) kernel support to BF16 precision in addition to FP8. Changes include CUDA kernel modifications that conditionally select QMMA shapes based on data type, introduce BF16-specific shared memory stores, adjust buffer layouts, and update Python API type validation to enforce (BF16, BF16) or (FP8, FP8) dtype pairs.

Changes

Cohort / File(s) Summary
BF16 MLA Kernel Implementation
csrc/xqa/mla_sm120.cu
Adds type-conditional QMMA shape selection (kernelQmmaShape), introduces new storeOrderedXToShmBf16 device method for BF16 writeback, modifies shared memory buffering logic with precision-aware nbKBufs, and replaces qmmaShape references with kernelQmmaShape throughout load/sync operations. Includes device-side memory limit validation and BF16-specific error messaging.
Python API Type Validation
flashinfer/mla.py
Updates trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla to enforce explicit dtype pair validation: accepts (fp8_e4m3fn, fp8_e4m3fn) or (bfloat16, bfloat16), rejecting other combinations with clarified error messages.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • cyx-6
  • wenscarl
  • nvmbreughe
  • bkryu
  • yzh119

Poem

🐰 A kernel blooms in BF16 hue,
Where QMMA shapes branch paths anew,
From FP8 soils, we've grown precision's tree,
Shared mem buffers dance with dtype decree,
storeOrderedXToShmBf16 writes the key! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 25.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): the title clearly and specifically describes the main change: adding BF16 support to MLA on SM120 with a shared memory fallback mechanism.
  • Description check (✅ Passed): the description includes all required sections with meaningful content: a clear explanation of changes, a related issue link, and a checklist acknowledging pre-commit and test requirements.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces BF16 support to the XQA MLA decode kernel on SM120/SM121 GPUs, while respecting shared memory limits. It includes changes to data types, tiling configurations, and device capability checks. The code has been reviewed for correctness, efficiency, and maintainability, with a focus on potential issues related to shared memory usage and type handling.

Comment on lines +716 to 718
mma<MathElem>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]),
reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j]));
Contributor


high

The mma template is being called with MathElem which is a good abstraction. However, it's important to ensure that MathElem is correctly deduced or explicitly specified to match the expected type by the mma instruction. If MathElem is not correctly deduced, it could lead to unexpected behavior or performance degradation.

Comment on lines +1680 to +1682
// BF16 with nbKBufs=2 uses ~100KB; allow up to 99KB opt-in (101376) for devices that support it.
static constexpr uint32_t kSmemLimitBytes = is_bf16 ? 101376 : 99 * 1024;
static_assert(smemSize <= kSmemLimitBytes, "Shared memory size exceeded");
Contributor


high

The comment indicates that BF16 uses ~100KB, but the limit is set to 99KB. This discrepancy should be resolved. Either the comment should be updated to reflect the actual usage, or the kSmemLimitBytes value should be adjusted if the hardware allows for the slightly larger size without performance degradation. If the 100KB usage is an estimate, it should be clarified as such.

Comment on lines +44 to +52
using MathElem = CacheElem;
inline constexpr uint32_t mathElemBytes = sizeof(MathElem);
inline constexpr bool is_fp8 = (mathElemBytes == 1);
inline constexpr bool is_bf16 = (mathElemBytes == 2);
// BF16: partElemsK=64, nbKBufs=2 → ~100KB, under 99KB opt-in (101376).
inline constexpr uint32_t partElemsK =
is_fp8 ? 64 :
is_bf16 ? 64 :
64;
Contributor


medium

The conditional assignment for partElemsK based on is_fp8 and is_bf16 is a good approach to optimize shared memory usage based on the data type. However, consider adding a static assertion to ensure that the chosen value of partElemsK and nbKBufs (defined later) results in a shared memory footprint within the 99KB limit. This will provide a compile-time check against exceeding the limit.


inline constexpr uint32_t grainElems = exactDiv(grainBytes, mathElemBytes);

inline constexpr mmaShape kernelQmmaShape = is_fp8 ? mmaShape{16, 8, 32} : mmaShape{16, 8, 16};
Contributor


medium

The kernelQmmaShape definition is dependent on is_fp8. It would be beneficial to add a comment explaining why these specific shapes are chosen for FP8 and BF16, referencing any relevant documentation or performance considerations.
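One way to see why the K extent of the MMA shape halves for BF16: a tensor core mma instruction consumes a fixed byte budget along K, so a 2-byte element type gets half the K elements of a 1-byte type. The byte-budget framing below is an explanatory simplification, not taken from the kernel source; the shapes themselves (m16n8k32 for FP8, m16n8k16 for BF16) match the ones the diff selects.

```python
# Illustrative model of the kernelQmmaShape selection: hold K-bytes fixed
# and derive the K element count from the element width. The 32-byte budget
# here is an assumption chosen to reproduce the two shapes in the diff.

def qmma_k(elem_bytes: int, k_bytes: int = 32) -> int:
    """K extent of the MMA shape for a given element width."""
    return k_bytes // elem_bytes

assert qmma_k(1) == 32  # FP8  (1 byte)  -> mmaShape{16, 8, 32}
assert qmma_k(2) == 16  # BF16 (2 bytes) -> mmaShape{16, 8, 16}
```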

Comment on lines +1731 to +1734
if (size > (uint32_t)devMaxShmem) {
throw std::runtime_error(
"XQA MLA kernel requires " + std::to_string(size) + " bytes shared memory per block, but "
"device opt-in max is " + std::to_string(devMaxShmem) + ". BF16 MLA needs 128 KB (e.g. SM12x).");
Contributor


medium

The error message is very helpful for debugging. However, consider adding the required shared memory size for FP8 as well, to provide a complete picture to the user.

throw std::runtime_error(
          "XQA MLA kernel requires " + std::to_string(size) + " bytes shared memory per block, but "
          "device opt-in max is " + std::to_string(devMaxShmem) + ". BF16 MLA needs 128 KB (e.g. SM12x), FP8 needs [size] KB.");

Comment on lines 611 to 613
raise ValueError(
-    f"XQA MLA only supports fp8 operation on SM120/SM121 GPUs, got {query.dtype} and {kv_cache.dtype}"
+    f"XQA MLA on SM120/SM121 supports (fp8, fp8) or (bfloat16, bfloat16) only, got {query.dtype} and {kv_cache.dtype}"
)
Contributor


medium

The error message could be improved by explicitly stating the expected data types for query and kv_cache when BF16 is enabled. This will help users quickly identify the correct data type configuration.

raise ValueError(
                f

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/mla.py (1)

764-781: ⚠️ Potential issue | 🟠 Major

Add SM capability check to direct xqa_batch_decode_with_kv_cache_mla entrypoint.

This public API can be called directly, bypassing trtllm_batch_decode_with_kv_cache_mla. Without an explicit capability check, unsupported devices fail later with less actionable errors. The check exists in the parent function (line 602-603) but must be replicated here since this is a public entry point.

Use the existing codebase pattern for consistency:

 def xqa_batch_decode_with_kv_cache_mla(
@@
 ) -> torch.Tensor:
@@
+    if get_compute_capability(query.device)[0] != 12:
+        raise ValueError("XQA MLA is only supported on SM120/SM121 GPUs")
     enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla.py` around lines 764 - 781, Add the same SM-capability guard
used in the parent entrypoint to the public function
xqa_batch_decode_with_kv_cache_mla: call device_support_sm(query.device) early
(before dtype/block-size checks) and raise a clear ValueError if it returns
False so unsupported GPUs fail fast; follow the existing pattern used in the
other entrypoint (use device_support_sm and the same error messaging style).
🧹 Nitpick comments (1)
flashinfer/mla.py (1)

604-613: Consider centralizing XQA dtype-pair validation.

The same (fp8, fp8) / (bf16, bf16) check is duplicated in two API paths. A small helper would prevent drift in behavior and messaging.

Also applies to: 774-781

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla.py` around lines 604 - 613, Extract the duplicated dtype-pair
check (the fp8_ok/bf16_ok logic and the ValueError message that mentions "XQA
MLA on SM120/SM121 supports (fp8, fp8) or (bfloat16, bfloat16) only") into a
single helper function (e.g., validate_xqa_dtype_pair(query, kv_cache)) and
replace both inline blocks (the one using fp8_ok/bf16_ok shown and the other at
the later location) with calls to that helper; ensure the helper performs the
same boolean checks against torch.float8_e4m3fn and torch.bfloat16 and raises
the identical ValueError message to keep behavior and messaging consistent.


📥 Commits

Reviewing files that changed from the base of the PR and between a6a60f1 and 6dfb06c.

📒 Files selected for processing (2)
  • csrc/xqa/mla_sm120.cu
  • flashinfer/mla.py

Comment on lines +1728 to +1730
int devMaxShmem = 0;
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
Contributor

@coderabbitai coderabbitai bot Mar 3, 2026


⚠️ Potential issue | 🟠 Major



Use the currently active CUDA device instead of hardcoded device 0 for SMEM capability checks.

Lines 1728-1730 and 1831-1833 query cudaDevAttrMaxSharedMemoryPerBlockOptin on device 0. On multi-GPU runs, these static initializers may execute with a different active device than intended, causing validation against the wrong GPU's capabilities and leading to incorrect pass/fail behavior.

The first check runs in a static lambda inside launchMLA() on first invocation, and the second runs at module initialization in configureKernel(), both potentially with undefined device context.

🔧 Proposed fix
-    int devMaxShmem = 0;
-    checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
-                                      cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
+    int devMaxShmem = 0;
+    int device = -1;
+    checkCuda(cudaGetDevice(&device));
+    checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
+                                     cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                                     device));
-  int devMaxShmem = 0;
-  checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
-                                    cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
+  int devMaxShmem = 0;
+  int device = -1;
+  checkCuda(cudaGetDevice(&device));
+  checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
+                                   cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                                   device));
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/xqa/mla_sm120.cu` around lines 1728 - 1730, The checks for
cudaDevAttrMaxSharedMemoryPerBlockOptin use a hardcoded device 0; change them to
query the currently active device first and use that device id for the attribute
call. Concretely, before calling cudaDeviceGetAttribute in launchMLA() and
configureKernel(), call cudaGetDevice(&dev) (or an equivalent helper), then pass
dev into cudaDeviceGetAttribute(&devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); preserve existing error handling
around checkCuda and ensure the variable names (devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, launchMLA, configureKernel) remain
consistent.

Collaborator


+1

Contributor


@saltyminty, thanks for confirming! Here's the fix applied to both locations:

In launchMLA() (static lambda, ~line 1728):

int devMaxShmem = 0;
int device = -1;
checkCuda(cudaGetDevice(&device));
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
                                 cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                 device));

In configureKernel() (~line 1831):

int devMaxShmem = 0;
int device = -1;
checkCuda(cudaGetDevice(&device));
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
                                 cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                 device));

Note that configureKernel() runs at static initialization time (static uint32_t const hostSmemSize = configureKernel();), so the active device context there may be particularly unpredictable. You may also want to ensure that hostSmemSize is not cached as a file-scope static if multi-GPU support with different GPU models is needed — or document the assumption that all devices in a multi-GPU setup are homogeneous.



blake-snc added a commit to blake-snc/flashinfer that referenced this pull request Mar 4, 2026
Validated on SM121a (DGX Spark GB10). The original PR flashinfer-ai#2675 kernel produces
100% NaN output for BF16. These fixes make it fully correct (max_diff < 11
microunits vs PyTorch reference across batch sizes 1/2/4, seq lengths
128/256/512/1024).

Bugs fixed:

1. Missing MLA_BF16 preprocessor flag in defines.h — BF16 MLA was compiling
   with FP8 INPUT_ELEM types, causing type mismatches throughout the kernel.

2. FP8-only JIT assertions in jit/xqa.py — gen_xqa_module_mla() asserted
   input_dtype == fp8, blocking BF16 compilation entirely. Added bf16 to
   allowed dtypes and set -DMLA_BF16=1 flag.

3. Q tensor map hardcoded 64B swizzle — partElemsK=64 with 2-byte BF16 =
   128 bytes per partition, requiring 128B swizzle. Made swizzle dynamic
   based on partBytes.

4. V tensor map 256-byte box exceeds max swizzle — partElemsV=128 with BF16
   = 256 bytes, but max TMA swizzle is 128B. Reduced partElemsV to 64 for
   BF16 (128 bytes, matches 128B swizzle).

5. Consumer .b8 ldmatrix transpose scrambles BF16 — ldmatrix_16x16_trans
   uses .b8 which byte-transposes, scrambling 2-byte BF16 values. Replaced
   with ldmatrix<true, 2> (.b16 transpose) for BF16 path.

6. Consumer OOB access rows 16-47 in 32-row buffer — FP8 has tokensPerTile=64
   but BF16 has tokensPerTile=32. The consumer V loading iterated over 48 rows
   (warpTileNbAtomBx2=3), accessing rows beyond the 32-row buffer. Restructured
   BF16 consumer to iterate over V parts within the 32-row tile.

7. V buffer 4-part layout incompatible with single-part consumer — FP8 uses
   partElemsV=128 (4 parts per V head), but BF16 needs partElemsV=64 to fit
   swizzle. Adjusted V splitting to match.

8. Register pressure causes stack overflow crash — BF16 doubles register
   usage vs FP8 for Q cache. Reduced buffer counts to stay within register
   budget.

9. storeOrderedXToShmBf16 OOB WarpAcc indexing — original implementation
   indexed WarpAcc as src(row/2, col/2)(row%2, col%2) which doesn't match
   MMA accumulator layout. Rewrote to use correct MMA register mapping:
   src(instM, instN)(iM, iN) at row=instM*16+lane/4+iM*8.

10. Q register prefetch idxAtomBx2==2 never triggers for BF16 — the GEMM0
    inner loop prefetches Q registers at idxAtomBx2==2, but BF16 has
    tileNbAtomBx2=2 (range 0..1), so the condition never fires. regQBuf[1..3]
    stay uninitialized → garbage GEMM0 → NaN softmax → NaN output. Fixed with
    constexpr qPrefetchAtomBx2 = min(2, tileNbAtomBx2-1).

Validation results on SM121a (DGX Spark):
  B=1 seq=128:  PASS, max_diff=0.000011, NaN=0
  B=1 seq=256:  PASS, max_diff=0.000006, NaN=0
  B=1 seq=512:  PASS, max_diff=0.000005, NaN=0
  B=1 seq=1024: PASS, max_diff=0.000004, NaN=0
  B=2 seq=128:  PASS, max_diff=0.000011, NaN=0
  B=2 seq=256:  PASS, max_diff=0.000007, NaN=0
  B=4 seq=128:  PASS, max_diff=0.000010, NaN=0
  B=4 seq=1024: PASS, max_diff=0.000005, NaN=0

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
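Bug 10 above (the prefetch guard that never fires) is worth a minimal model, since it explains the 100% NaN output. The loop structure and the FP8 value of tileNbAtomBx2 below are assumptions for illustration; only the BF16 value of 2 and the `min(2, tileNbAtomBx2 - 1)` fix come from the commit message.

```python
# Minimal model of the Q-register prefetch bug: the GEMM0 inner loop
# iterates idxAtomBx2 over range(tileNbAtomBx2) and prefetches when the
# index hits a trigger value. A hardcoded trigger of 2 never fires when
# the loop range is only 0..1, leaving regQBuf uninitialized.

def prefetch_fires(tile_nb_atom_bx2: int, trigger: int) -> bool:
    """Does any loop iteration hit the prefetch trigger?"""
    return any(idx == trigger for idx in range(tile_nb_atom_bx2))

assert prefetch_fires(4, 2)       # FP8-like range (assumed): trigger fires
assert not prefetch_fires(2, 2)   # BF16: range is 0..1, trigger 2 never fires
# The fix clamps the trigger to the last valid iteration:
assert prefetch_fires(2, min(2, 2 - 1))
```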
@bkryu
Collaborator

bkryu commented Mar 5, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !378 has been created, and the CI pipeline #45373775 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #45373775: 9/20 passed

@saltyminty
Collaborator

saltyminty commented Mar 12, 2026

Right now, in flashinfer/jit/xqa.py::gen_xqa_module_mla, kv is still asserted to be fp8. Is this an issue? It seems this will also cause relevant tests to be skipped.
