Support BF16 MLA on SM120 with shared-mem fallback #2675

maomao123321 wants to merge 1 commit into flashinfer-ai:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request enhances the XQA MLA decode kernel by introducing bfloat16 (BF16) support for SM120/SM121 GPUs. The changes tune shared memory usage and buffer configurations to accommodate BF16 while adhering to hardware limitations, and add runtime checks to ensure compatibility and graceful fallback, making the system more versatile for different precision requirements.
Activity
📝 Walkthrough

This PR extends MLA (Multi-Head Latent Attention) kernel support to BF16 precision in addition to FP8. Changes include CUDA kernel modifications that conditionally select QMMA shapes based on data type, introduce BF16-specific shared memory stores, adjust buffer layouts, and update Python API type validation to enforce (BF16, BF16) or (FP8, FP8) dtype pairs.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces BF16 support to the XQA MLA decode kernel on SM120/SM121 GPUs, while respecting shared memory limits. It includes changes to data types, tiling configurations, and device capability checks. The code has been reviewed for correctness, efficiency, and maintainability, with a focus on potential issues related to shared memory usage and type handling.
```cpp
mma<MathElem>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
              reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]),
              reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j]));
```
The mma template is being called with MathElem which is a good abstraction. However, it's important to ensure that MathElem is correctly deduced or explicitly specified to match the expected type by the mma instruction. If MathElem is not correctly deduced, it could lead to unexpected behavior or performance degradation.
```cpp
// BF16 with nbKBufs=2 uses ~100KB; allow up to 99KB opt-in (101376) for devices that support it.
static constexpr uint32_t kSmemLimitBytes = is_bf16 ? 101376 : 99 * 1024;
static_assert(smemSize <= kSmemLimitBytes, "Shared memory size exceeded");
```
The comment indicates that BF16 uses ~100KB, but the limit is set to 99KB. This discrepancy should be resolved. Either the comment should be updated to reflect the actual usage, or the kSmemLimitBytes value should be adjusted if the hardware allows for the slightly larger size without performance degradation. If the 100KB usage is an estimate, it should be clarified as such.
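Incidentally, the two branches of that ternary are numerically identical, so as written `kSmemLimitBytes` does not actually vary by dtype. A standalone sanity calculation (not code from the PR) confirms this:

```python
# 99 KiB expressed in bytes equals the BF16 opt-in value above, so
# `is_bf16 ? 101376 : 99 * 1024` selects the same limit either way.
bf16_limit = 101376
fp8_limit = 99 * 1024
print(bf16_limit, fp8_limit, bf16_limit == fp8_limit)  # 101376 101376 True
```

So the "99KB opt-in (101376)" part of the comment is self-consistent; the open question is only whether the "~100KB" usage estimate truly fits under it.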
```cpp
using MathElem = CacheElem;
inline constexpr uint32_t mathElemBytes = sizeof(MathElem);
inline constexpr bool is_fp8 = (mathElemBytes == 1);
inline constexpr bool is_bf16 = (mathElemBytes == 2);
// BF16: partElemsK=64, nbKBufs=2 → ~100KB, under 99KB opt-in (101376).
inline constexpr uint32_t partElemsK =
    is_fp8  ? 64 :
    is_bf16 ? 64 :
              64;
```
The conditional assignment for partElemsK based on is_fp8 and is_bf16 is a good approach to optimize shared memory usage based on the data type. However, consider adding a static assertion to ensure that the chosen value of partElemsK and nbKBufs (defined later) results in a shared memory footprint within the 99KB limit. This will provide a compile-time check against exceeding the limit.
```cpp
inline constexpr uint32_t grainElems = exactDiv(grainBytes, mathElemBytes);
```

```cpp
inline constexpr mmaShape kernelQmmaShape = is_fp8 ? mmaShape{16, 8, 32} : mmaShape{16, 8, 16};
```
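One way to read the QMMA shape switch above: the K extent halves when the element width doubles, so the per-instruction K footprint in bytes stays constant. A quick sanity check (shapes copied from the snippet; byte widths assumed as FP8 = 1 byte, BF16 = 2 bytes):

```python
# (m, n, k) QMMA shapes from the snippet above.
fp8_shape = (16, 8, 32)   # 1-byte FP8 elements
bf16_shape = (16, 8, 16)  # 2-byte BF16 elements

# K extent in bytes is identical for both dtypes.
fp8_k_bytes = fp8_shape[2] * 1
bf16_k_bytes = bf16_shape[2] * 2
print(fp8_k_bytes, bf16_k_bytes)  # 32 32
```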
```cpp
if (size > (uint32_t)devMaxShmem) {
  throw std::runtime_error(
      "XQA MLA kernel requires " + std::to_string(size) + " bytes shared memory per block, but "
      "device opt-in max is " + std::to_string(devMaxShmem) + ". BF16 MLA needs 128 KB (e.g. SM12x).");
}
```
The error message is very helpful for debugging. However, consider adding the required shared memory size for FP8 as well, to provide a complete picture to the user.
```cpp
throw std::runtime_error(
    "XQA MLA kernel requires " + std::to_string(size) + " bytes shared memory per block, but "
    "device opt-in max is " + std::to_string(devMaxShmem) + ". BF16 MLA needs 128 KB (e.g. SM12x), FP8 needs [size] KB.");
```
```diff
 raise ValueError(
-    f"XQA MLA only supports fp8 operation on SM120/SM121 GPUs, got {query.dtype} and {kv_cache.dtype}"
+    f"XQA MLA on SM120/SM121 supports (fp8, fp8) or (bfloat16, bfloat16) only, got {query.dtype} and {kv_cache.dtype}"
 )
```
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/mla.py (1)
764-781: ⚠️ Potential issue | 🟠 Major: Add SM capability check to the direct `xqa_batch_decode_with_kv_cache_mla` entrypoint.

This public API can be called directly, bypassing `trtllm_batch_decode_with_kv_cache_mla`. Without an explicit capability check, unsupported devices fail later with less actionable errors. The check exists in the parent function (lines 602-603) but must be replicated here since this is a public entry point. Use the existing codebase pattern for consistency:

```diff
 def xqa_batch_decode_with_kv_cache_mla(
 @@
 ) -> torch.Tensor:
 @@
+    if get_compute_capability(query.device)[0] != 12:
+        raise ValueError("XQA MLA is only supported on SM120/SM121 GPUs")
     enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mla.py` around lines 764 - 781, Add the same SM-capability guard used in the parent entrypoint to the public function xqa_batch_decode_with_kv_cache_mla: call device_support_sm(query.device) early (before dtype/block-size checks) and raise a clear ValueError if it returns False so unsupported GPUs fail fast; follow the existing pattern used in the other entrypoint (use device_support_sm and the same error messaging style).
🧹 Nitpick comments (1)
flashinfer/mla.py (1)
604-613: Consider centralizing XQA dtype-pair validation.

The same (fp8, fp8) / (bf16, bf16) check is duplicated in two API paths. A small helper would prevent drift in behavior and messaging.

Also applies to: 774-781
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mla.py` around lines 604 - 613, Extract the duplicated dtype-pair check (the fp8_ok/bf16_ok logic and the ValueError message that mentions "XQA MLA on SM120/SM121 supports (fp8, fp8) or (bfloat16, bfloat16) only") into a single helper function (e.g., validate_xqa_dtype_pair(query, kv_cache)) and replace both inline blocks (the one using fp8_ok/bf16_ok shown and the other at the later location) with calls to that helper; ensure the helper performs the same boolean checks against torch.float8_e4m3fn and torch.bfloat16 and raises the identical ValueError message to keep behavior and messaging consistent.
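A minimal sketch of the suggested helper. The function name follows the review's suggestion, but the parameterization is hypothetical: dtypes are passed in (placeholder strings here) so the logic can be exercised without torch, whereas in flashinfer the comparisons would be against torch.float8_e4m3fn and torch.bfloat16.

```python
def validate_xqa_dtype_pair(query_dtype, kv_dtype, *, fp8, bf16):
    """Raise unless (query, kv) dtypes are (fp8, fp8) or (bf16, bf16)."""
    fp8_ok = query_dtype == fp8 and kv_dtype == fp8
    bf16_ok = query_dtype == bf16 and kv_dtype == bf16
    if not (fp8_ok or bf16_ok):
        raise ValueError(
            "XQA MLA on SM120/SM121 supports (fp8, fp8) or (bfloat16, bfloat16) "
            f"only, got {query_dtype} and {kv_dtype}"
        )

# Matched pairs pass; a mixed pair raises the shared message.
validate_xqa_dtype_pair("bfloat16", "bfloat16", fp8="fp8e4m3", bf16="bfloat16")
```

Both API paths would then call this one helper, so the check and the error text cannot drift apart.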
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/xqa/mla_sm120.cu`:
- Around line 1728-1730: The checks for cudaDevAttrMaxSharedMemoryPerBlockOptin
use a hardcoded device 0; change them to query the currently active device first
and use that device id for the attribute call. Concretely, before calling
cudaDeviceGetAttribute in launchMLA() and configureKernel(), call
cudaGetDevice(&dev) (or an equivalent helper), then pass dev into
cudaDeviceGetAttribute(&devMaxShmem, cudaDevAttrMaxSharedMemoryPerBlockOptin,
dev); preserve existing error handling around checkCuda and ensure the variable
names (devMaxShmem, cudaDevAttrMaxSharedMemoryPerBlockOptin, launchMLA,
configureKernel) remain consistent.
```cpp
int devMaxShmem = 0;
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
                                 cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
```
Use the currently active CUDA device instead of hardcoded device 0 for SMEM capability checks.
Lines 1728-1730 and 1831-1833 query cudaDevAttrMaxSharedMemoryPerBlockOptin on device 0. On multi-GPU runs, these static initializers may execute with a different active device than intended, causing validation against the wrong GPU's capabilities and leading to incorrect pass/fail behavior.
The first check runs in a static lambda inside launchMLA() on first invocation, and the second runs at module initialization in configureKernel(), both potentially with undefined device context.
🔧 Proposed fix (the same change applies at both locations):

```diff
- int devMaxShmem = 0;
- checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
-                                  cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
+ int devMaxShmem = 0;
+ int device = -1;
+ checkCuda(cudaGetDevice(&device));
+ checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
+                                  cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                                  device));
```
+ device));📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| int devMaxShmem = 0; | |
| checkCuda(cudaDeviceGetAttribute(&devMaxShmem, | |
| cudaDevAttrMaxSharedMemoryPerBlockOptin, 0)); | |
| int devMaxShmem = 0; | |
| int device = -1; | |
| checkCuda(cudaGetDevice(&device)); | |
| checkCuda(cudaDeviceGetAttribute(&devMaxShmem, | |
| cudaDevAttrMaxSharedMemoryPerBlockOptin, | |
| device)); |
@saltyminty, thanks for confirming! Here's the fix applied to both locations:

In launchMLA() (static lambda, ~line 1728):

```cpp
int devMaxShmem = 0;
int device = -1;
checkCuda(cudaGetDevice(&device));
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
                                 cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                 device));
```

In configureKernel() (~line 1831):

```cpp
int devMaxShmem = 0;
int device = -1;
checkCuda(cudaGetDevice(&device));
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
                                 cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                 device));
```

Note that configureKernel() runs at static initialization time (static uint32_t const hostSmemSize = configureKernel();), so the active device context there may be particularly unpredictable. You may also want to ensure that hostSmemSize is not cached as a file-scope static if multi-GPU support with different GPU models is needed, or document the assumption that all devices in a multi-GPU setup are homogeneous.
Validated on SM121a (DGX Spark GB10). The original PR flashinfer-ai#2675 kernel produces 100% NaN output for BF16. These fixes make it fully correct (max_diff < 11 microunits vs PyTorch reference across batch sizes 1/2/4, seq lengths 128/256/512/1024).

Bugs fixed:

1. Missing MLA_BF16 preprocessor flag in defines.h: BF16 MLA was compiling with FP8 INPUT_ELEM types, causing type mismatches throughout the kernel.
2. FP8-only JIT assertions in jit/xqa.py: gen_xqa_module_mla() asserted input_dtype == fp8, blocking BF16 compilation entirely. Added bf16 to allowed dtypes and set -DMLA_BF16=1 flag.
3. Q tensor map hardcoded 64B swizzle: partElemsK=64 with 2-byte BF16 = 128 bytes per partition, requiring 128B swizzle. Made swizzle dynamic based on partBytes.
4. V tensor map 256-byte box exceeds max swizzle: partElemsV=128 with BF16 = 256 bytes, but max TMA swizzle is 128B. Reduced partElemsV to 64 for BF16 (128 bytes, matches 128B swizzle).
5. Consumer .b8 ldmatrix transpose scrambles BF16: ldmatrix_16x16_trans uses .b8 which byte-transposes, scrambling 2-byte BF16 values. Replaced with ldmatrix<true, 2> (.b16 transpose) for BF16 path.
6. Consumer OOB access rows 16-47 in 32-row buffer: FP8 has tokensPerTile=64 but BF16 has tokensPerTile=32. The consumer V loading iterated over 48 rows (warpTileNbAtomBx2=3), accessing rows beyond the 32-row buffer. Restructured BF16 consumer to iterate over V parts within the 32-row tile.
7. V buffer 4-part layout incompatible with single-part consumer: FP8 uses partElemsV=128 (4 parts per V head), but BF16 needs partElemsV=64 to fit swizzle. Adjusted V splitting to match.
8. Register pressure causes stack overflow crash: BF16 doubles register usage vs FP8 for Q cache. Reduced buffer counts to stay within register budget.
9. storeOrderedXToShmBf16 OOB WarpAcc indexing: original implementation indexed WarpAcc as src(row/2, col/2)(row%2, col%2) which doesn't match MMA accumulator layout. Rewrote to use correct MMA register mapping: src(instM, instN)(iM, iN) at row=instM*16+lane/4+iM*8.
10. Q register prefetch idxAtomBx2==2 never triggers for BF16: the GEMM0 inner loop prefetches Q registers at idxAtomBx2==2, but BF16 has tileNbAtomBx2=2 (range 0..1), so the condition never fires. regQBuf[1..3] stay uninitialized → garbage GEMM0 → NaN softmax → NaN output. Fixed with constexpr qPrefetchAtomBx2 = min(2, tileNbAtomBx2-1).

Validation results on SM121a (DGX Spark):

- B=1 seq=128: PASS, max_diff=0.000011, NaN=0
- B=1 seq=256: PASS, max_diff=0.000006, NaN=0
- B=1 seq=512: PASS, max_diff=0.000005, NaN=0
- B=1 seq=1024: PASS, max_diff=0.000004, NaN=0
- B=2 seq=128: PASS, max_diff=0.000011, NaN=0
- B=2 seq=256: PASS, max_diff=0.000007, NaN=0
- B=4 seq=128: PASS, max_diff=0.000010, NaN=0
- B=4 seq=1024: PASS, max_diff=0.000005, NaN=0

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
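Two of the fixes above (3 and 10) reduce to small arithmetic rules that can be sketched independently of the kernel. Function names are hypothetical and the constants are taken from the bug descriptions, not from the actual source:

```python
def tma_swizzle_bytes(part_elems_k, elem_bytes):
    """Fix 3 sketch: pick the TMA swizzle to match the partition width
    in bytes, instead of hardcoding 64B; TMA supports at most 128B."""
    part_bytes = part_elems_k * elem_bytes
    if part_bytes > 128:
        raise ValueError("partition exceeds the 128B TMA swizzle maximum")
    return part_bytes

def q_prefetch_atom(tile_nb_atom_bx2):
    """Fix 10 sketch: clamp the prefetch trigger so it falls inside the
    loop range 0..tileNbAtomBx2-1 (per the commit, BF16 has tileNbAtomBx2=2,
    so a hardcoded trigger of 2 never fires)."""
    return min(2, tile_nb_atom_bx2 - 1)

print(tma_swizzle_bytes(64, 1))  # FP8: 64B swizzle
print(tma_swizzle_bytes(64, 2))  # BF16: 128B swizzle
print(q_prefetch_atom(2))        # BF16: trigger at index 1, not 2
```

The same `tma_swizzle_bytes` rule also explains fix 4: partElemsV=128 with 2-byte BF16 gives 256 bytes, which exceeds the 128B maximum, hence the reduction to partElemsV=64.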
/bot run

[SUCCESS] Pipeline #45373775: 9/20 passed

Right now, in flashinfer/jit/xqa.py::gen_xqa_module_mla, kv is still asserted to be fp8. Is this an issue? It seems this will also cause relevant tests to be skipped.
📌 Description
This PR adds BF16 support to the XQA MLA decode kernel on SM120/SM121 while respecting the 99 KB per-block shared memory limit on consumer Blackwell GPUs. It reuses the existing MLA pipeline and MMA infrastructure, tunes the BF16 tiling and buffer configuration, and adds runtime/device capability checks so that BF16 MLA is enabled on GPUs with sufficient shared memory and falls back cleanly to FA2 on devices that cannot support the full MLA configuration.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Chores