Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 103 additions & 33 deletions csrc/xqa/mla_sm120.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,23 @@ __constant__ constexpr XQAKernelType kernelType = XQAKernelType::kSM120_MLA;

inline constexpr bool allowMultipleInputTokens = true;

using MathElem = CacheElem;
inline constexpr uint32_t mathElemBytes = sizeof(MathElem);
// Element-type flags derived from the math element width:
// 1 byte -> FP8 (e.g. __nv_fp8_e4m3), 2 bytes -> BF16.
inline constexpr bool is_fp8 = (mathElemBytes == 1);
inline constexpr bool is_bf16 = (mathElemBytes == 2);
// Guard against an unsupported CacheElem width silently taking default paths.
static_assert(is_fp8 || is_bf16, "MathElem must be 1 byte (FP8) or 2 bytes (BF16)");
// Per-part K extent. Currently 64 for both FP8 and BF16 (the previous
// three-way conditional had identical branches).
// BF16 note: partElemsK=64 with nbKBufs=2 keeps shared memory (~100,096 B)
// under the 99 KiB (101,376 B) opt-in limit — see kSmemLimitBytes below.
// @fixme: consider 128 to save L2 traffic.
inline constexpr uint32_t partElemsK = 64;
Comment on lines +44 to +52
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The conditional assignment for partElemsK based on is_fp8 and is_bf16 is a good approach to optimize shared memory usage based on the data type. However, consider adding a static assertion to ensure that the chosen values of partElemsK and nbKBufs (defined later) result in a shared memory footprint within the 99KB limit. This will provide a compile-time check against exceeding the limit.

inline constexpr uint32_t nbKParts = exactDiv(validElemsPerKHead, partElemsK);
inline constexpr uint32_t nbQParts = nbKParts;

// Tokens processed per tile: BF16 halves the tile depth (its 2-byte elements
// double the buffer footprint); FP8 and any other width keep 64.
inline constexpr uint32_t tokensPerTile = is_bf16 ? 32 : 64;

inline constexpr uint32_t partElemsV = 128;
inline constexpr uint32_t nbVSplit = 2;
inline constexpr uint32_t gemm1V = exactDiv(validElemsPerVHead, nbVSplit);
Expand All @@ -54,12 +66,12 @@ inline constexpr uint32_t nbProducerCtasPerCga = nbVSplit;
inline constexpr uint32_t multiBlockMinNbTilesPerCta = 2;
inline constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2;

using MathElem = CacheElem;
inline constexpr uint32_t mathElemBytes = sizeof(MathElem);
inline constexpr uint32_t grainsPerPartK = exactDiv(partElemsK * mathElemBytes, grainBytes);

inline constexpr uint32_t grainElems = exactDiv(grainBytes, mathElemBytes);

inline constexpr mmaShape kernelQmmaShape = is_fp8 ? mmaShape{16, 8, 32} : mmaShape{16, 8, 16};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The kernelQmmaShape definition is dependent on is_fp8. It would be beneficial to add a comment explaining why these specific shapes are chosen for FP8 and BF16, referencing any relevant documentation or performance considerations.


inline constexpr float xScale = 1.f / kE4M3_MAX;
__constant__ constexpr float rcpXScale = kE4M3_MAX;

Expand Down Expand Up @@ -162,16 +174,16 @@ class Mat16x32Loader {
__device__ inline Mat16x32Loader(Src const& src, uint32_t baseRow, uint32_t idxInstK,
uint32_t r = laneId() % 16, uint32_t c = laneId() / 16)
: src{src}, baseRow{baseRow}, idxInstK{idxInstK}, r{r}, c{c}, basePtr{getPtrRef(0)} {
static_assert((grainBytes * srcCols * qmmaShape.m) % 1024 == 0);
static_assert((grainBytes * srcCols * kernelQmmaShape.m) % 1024 == 0);
}

__device__ inline Mat16x32 load(uint32_t idxInstM) const {
return ldmatrix<false, 4>(getPtr(idxInstM));
}

template <uint32_t tileM>
__device__ inline Vec<Mat16x32, exactDiv(tileM, qmmaShape.m)> loadWholeCol() const {
uint32_t const nbInstM = exactDiv(tileM, qmmaShape.m);
__device__ inline Vec<Mat16x32, exactDiv(tileM, kernelQmmaShape.m)> loadWholeCol() const {
uint32_t const nbInstM = exactDiv(tileM, kernelQmmaShape.m);
Vec<Mat16x32, nbInstM> ret;
#pragma unroll
for (uint32_t i = 0; i < nbInstM; i++) {
Expand All @@ -181,13 +193,13 @@ class Mat16x32Loader {
}

__device__ inline LdGrain const* getPtr(uint32_t idxInstM) const {
return checkedVal(basePtr + idxInstM * qmmaShape.m * srcCols, getPtrRef(idxInstM));
return checkedVal(basePtr + idxInstM * kernelQmmaShape.m * srcCols, getPtrRef(idxInstM));
}

private:
__device__ inline LdGrain const* getPtrRef(uint32_t idxInstM) const {
return &src.template at<true>(baseRow + idxInstM * qmmaShape.m + r,
idxInstK * exactDiv(qmmaShape.k, grainElems) + c);
return &src.template at<true>(baseRow + idxInstM * kernelQmmaShape.m + r,
idxInstK * exactDiv(kernelQmmaShape.k, grainElems) + c);
}

Src const& src;
Expand Down Expand Up @@ -263,7 +275,9 @@ constexpr uint32_t multiBlockMathWarps = 8;
constexpr bool useRegQ = USE_REG_Q;

struct SharedMemA {
static inline constexpr uint32_t nbKBufs = 12;
// BF16: 2 K-buffers to fit ≀99KB opt-in (~100096 bytes); 3 buffers would need ~104KB (128KB arch).
static inline constexpr uint32_t nbKBufs =
is_fp8 ? 12 : (is_bf16 ? 2 : 12);

static inline constexpr uint32_t regQParts = (useRegQ ? 4 : 0);
static inline constexpr uint32_t shmQParts = nbQParts - regQParts;
Expand Down Expand Up @@ -587,12 +601,12 @@ struct Producer {
uint32_t const tileBaseRow = warpTile.y * warpIdx.x;
PingPongMutex tensorCoreMutex{smem.tensorCoreMutex, grpIdx};

constexpr uint32_t partNbInstK = exactDiv(partElemsK, qmmaShape.k);
constexpr uint32_t partNbInstK = exactDiv(partElemsK, kernelQmmaShape.k);
using AtomA = Vec<uint32_t, 4>; // for 16x32 data, working as mat A of QMMA.16832
using RegQPartCol = Vec<AtomA, exactDiv(warpTile.y, qmmaShape.m)>;
using RegQPartCol = Vec<AtomA, exactDiv(warpTile.y, kernelQmmaShape.m)>;
using RegQPart = Vec<RegQPartCol, partNbInstK>;
using RegQ = Vec<RegQPart, SharedMemA::regQParts>;
constexpr uint32_t tileNbAtomBx2 = exactDiv(tokensPerTile, qmmaShape.n * 2);
constexpr uint32_t tileNbAtomBx2 = exactDiv(tokensPerTile, kernelQmmaShape.n * 2);
using AtomBx2 = Vec<uint32_t, 4>; // one AtomB is 8x32 and AtomBx2 is 16x32
using RegKPartCol = Vec<AtomBx2, tileNbAtomBx2>;
using RegKPart = Vec<RegKPartCol, partNbInstK>;
Expand Down Expand Up @@ -656,7 +670,8 @@ struct Producer {
RegKPart regKBuf;
regKBuf[0] = loadRegKCol(smem.k[kBarWaiter.idxBuf], 0);

auto shouldTestWait = [](uint32_t idxInstK, uint32_t idxAtomBx2) {
auto shouldTestWait = [partNbInstK, tileNbAtomBx2](uint32_t idxInstK,
uint32_t idxAtomBx2) {
return idxInstK == partNbInstK - 1 && idxAtomBx2 == tileNbAtomBx2 - 2;
};
BarWaiter kBarWaiterNext = kBarWaiter.next();
Expand Down Expand Up @@ -698,7 +713,7 @@ struct Producer {
for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
mma<MathElem>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]),
reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j]));
Comment on lines +716 to 718
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The mma template is being called with MathElem which is a good abstraction. However, it's important to ensure that MathElem is correctly deduced or explicitly specified to match the expected type by the mma instruction. If MathElem is not correctly deduced, it could lead to unexpected behavior or performance degradation.

}
Expand Down Expand Up @@ -749,7 +764,7 @@ struct Producer {
for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
mma<MathElem>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]),
reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j]));
}
Expand All @@ -776,14 +791,14 @@ struct Producer {

auto& xBar = smem.xBars[grpIdx];
bool const skipXBarWait = xBar.consumed.test_wait_parity(toParity<1>(grpIter));
// convert to fp8
ThrdRegRowMax rowSum;
if constexpr (is_fp8) {
WarpAcc const xF32Quant = xF32 * rcpXScale;
// 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15
Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> xF8;
#pragma unroll
for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
for (uint32_t m = 0; m < exactDiv(qmmaShape.m, 8); m++) {
for (uint32_t m = 0; m < exactDiv(kernelQmmaShape.m, 8); m++) {
#pragma unroll
for (uint32_t j = 0; j < WarpAcc::cols; j += 2) {
auto& dst = reinterpret_cast<__nv_fp8x2_e4m3(&)[2]>(xF8(i, j / 2)(m, 0));
Expand All @@ -792,18 +807,24 @@ struct Producer {
}
}
}
// use tensor core to compute rowSum
ThrdRegRowMax const rowSum =
rowSum =
computeRowSumFromF8 ? computeRowSumF8<warpTile.y, warpTile.x>(this_warp(), xF8)
: computeRowSumF32<warpTile.y, warpTile.x>(this_warp(), xF32);

// store xF8 and rowSum into L2 scratch buffer
if (!skipXBarWait) {
xBar.consumed.wait_parity(toParity<1>(grpIter));
}
storeRowMax<warpTile.y>(smem.x.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane);
storeRowMax<warpTile.y>(smem.x.rowSum, rowSum, tileBaseRow, lane);
storeOrderedXToShm(smem.x.x, xF8, tileBaseRow, lane);
} else {
rowSum = computeRowSumF32<warpTile.y, warpTile.x>(this_warp(), xF32);
if (!skipXBarWait) {
xBar.consumed.wait_parity(toParity<1>(grpIter));
}
storeRowMax<warpTile.y>(smem.x.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane);
storeRowMax<warpTile.y>(smem.x.rowSum, rowSum, tileBaseRow, lane);
storeOrderedXToShmBf16(smem.x.x, xF32, tileBaseRow, lane);
}
xBar.produced.arrive();
}
}
Expand All @@ -816,6 +837,9 @@ struct Producer {
XBuffer& dst,
Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src,
uint32_t const tileBaseRow, uint32_t const lane = laneId());
__device__ inline void storeOrderedXToShmBf16(XBuffer& dst, WarpAcc const& src,
uint32_t const tileBaseRow,
uint32_t const lane = laneId());
};

__device__ inline void Producer::loadK() {
Expand Down Expand Up @@ -966,6 +990,29 @@ __device__ inline void Producer::storeOrderedXToShm(
}
}

// Converts this warp's FP32 accumulator tile to BF16 and stores it into the
// shared-memory X buffer, one 16-byte grain per (lane, iteration) pair.
// dst:         destination X buffer in shared memory.
// src:         per-thread fragment of the warp-level accumulator (WarpAcc).
// tileBaseRow: row offset of this warp's tile within the buffer.
// lane:        lane id within the warp (defaults to laneId()).
__device__ inline void Producer::storeOrderedXToShmBf16(XBuffer& dst, WarpAcc const& src,
                                                        uint32_t const tileBaseRow,
                                                        uint32_t const lane) {
  // Grains per output row after converting warpTile.x floats to bf16.
  constexpr uint32_t grainsPerRow = exactDiv(warpTile.x * sizeof(__nv_bfloat16), grainBytes);
  constexpr uint32_t totalGrains = warpTile.y * grainsPerRow;
  // Grains handled by each of the 32 lanes; exactDiv asserts even division.
  constexpr uint32_t grainsPerThread = exactDiv(totalGrains, 32);
#pragma unroll
  for (uint32_t i = 0; i < grainsPerThread; i++) {
    // Linear grain index striped across the warp: lane, lane+32, lane+64, ...
    uint32_t const idx = lane + i * 32;
    uint32_t const row = idx / grainsPerRow;
    uint32_t const g = idx % grainsPerRow;
    // Defensive bound; with grainsPerThread * 32 == totalGrains this is
    // always true — presumably kept as a safety net. TODO confirm.
    if (row < warpTile.y) {
      __nv_bfloat16* p =
          reinterpret_cast<__nv_bfloat16*>(&dst.template at<true>(tileBaseRow + row, g));
#pragma unroll
      // 8 bf16 elements per grain — assumes grainBytes == 16; TODO confirm.
      for (uint32_t j = 0; j < 8; j++) {
        uint32_t const col = g * 8 + j;
        // NOTE(review): this reads the calling lane's own register fragment at
        // (row/2, col/2)(row%2, col%2). Verify that this per-lane mapping
        // matches the MMA accumulator fragment layout for arbitrary (row, col)
        // — each lane only holds specific accumulator elements.
        p[j] = __float2bfloat16(src(row / 2, col / 2)(row % 2, col % 2));
      }
    }
  }
}

struct Consumer {
static inline constexpr uint32_t nbMathWarps = nbMathWarpsB;
static inline constexpr uint32_t nbMathThrds = warp_size * nbMathWarps;
Expand Down Expand Up @@ -1115,8 +1162,8 @@ __device__ inline void Consumer::compute() {
uint2 const tileIdx = {warpIdx.y, warpIdx.x};
uint2 const tileBase = {tileIdx.x * warpTile.x, tileIdx.y * warpTile.y};

constexpr uint32_t tileNbInstK = exactDiv(tokensPerTile, qmmaShape.k);
constexpr uint32_t warpTileNbAtomBx2 = exactDiv(warpTile.x, qmmaShape.n * 2);
constexpr uint32_t tileNbInstK = exactDiv(tokensPerTile, kernelQmmaShape.k);
constexpr uint32_t warpTileNbAtomBx2 = exactDiv(warpTile.x, kernelQmmaShape.n * 2);

uint32_t const lane = laneId();
uint32_t const idxHalf = lane / 16;
Expand Down Expand Up @@ -1195,19 +1242,19 @@ __device__ inline void Consumer::compute() {
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < tileNbInstK; idxInstK++) {
Mat16x32Loader const loaderX(xBuf, tileBase.y, idxInstK, rA, cA);
Vec<Mat16x32, exactDiv(warpTile.y, qmmaShape.m)> const x = loaderX.loadWholeCol<warpTile.y>();
Vec<Mat16x32, exactDiv(warpTile.y, kernelQmmaShape.m)> const x = loaderX.loadWholeCol<warpTile.y>();
using AtomB = Vec<uint32_t, 2>;
#pragma unroll
for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < warpTileNbAtomBx2; idxAtomBx2++) {
auto const data = ldmatrix_16x16_trans<2>(
&vBuf.template at<true>(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
&vBuf.template at<true>(kernelQmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
AtomB const v[2] = {data[0], data[2], data[1], data[3]};
#pragma unroll
for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#if 1
mma<__nv_fp8_e4m3>(
mma<MathElem>(
#else
mmaF8_k32_2inst(
#endif
Expand Down Expand Up @@ -1630,7 +1677,9 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha
}

__constant__ constexpr uint32_t smemSize = mha::max(sizeof(SharedMemA), sizeof(SharedMemB));
// Per-block opt-in shared-memory budget. Note: 99 * 1024 == 101376, so the
// two branches below are currently the same value; the branch is kept only so
// the BF16 budget can diverge from FP8 later without touching the assert.
// BF16 with nbKBufs=2 uses ~100,096 bytes, which fits this 99 KiB limit.
static constexpr uint32_t kSmemLimitBytes = is_bf16 ? 101376u : 99u * 1024u;
static_assert(smemSize <= kSmemLimitBytes, "Shared memory size exceeded");
Comment on lines +1680 to +1682
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The comment indicates that BF16 uses ~100KB, but the limit is set to 99KB. This discrepancy should be resolved. Either the comment should be updated to reflect the actual usage, or the kSmemLimitBytes value should be adjusted if the hardware allows for the slightly larger size without performance degradation. If the 100KB usage is an estimate, it should be clarified as such.

#endif // is_MLA

#ifndef GENERATE_CUBIN
Expand Down Expand Up @@ -1674,9 +1723,19 @@ void launchMLA(
throw std::runtime_error("not implemented");
}
static uint32_t const hostSmemSize = [&]() {
// printf("smemSize = %u\n", smemSize);
uint32_t size;
checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize)));
int devMaxShmem = 0;
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
Comment on lines +1728 to +1730
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's check the file size and then examine the specific lines
wc -l csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

# Examine the code around lines 1728-1730
sed -n '1720,1740p' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 1167


🏁 Script executed:

# Examine the code around lines 1831-1833
sed -n '1823,1843p' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 962


🏁 Script executed:

# Run the verification command from the review comment
rg -n -C2 'cudaDeviceGetAttribute\(&devMaxShmem,\s*cudaDevAttrMaxSharedMemoryPerBlockOptin,\s*0\)' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Also check for the pattern more broadly to understand context
rg -n 'devMaxShmem' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 547


🏁 Script executed:

# Get more context around the first occurrence (lambda function)
sed -n '1710,1750p' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 2243


🏁 Script executed:

# Get function signature and more context for configureKernel
sed -n '1825,1855p' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 1539


🏁 Script executed:

# Search for where configureKernel and the lambda are called
rg -n 'configureKernel\|hostSmemSize' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if cudaGetDevice is already used elsewhere in the file
rg -n 'cudaGetDevice' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look for cudaSetDevice to understand multi-GPU context
rg -n 'cudaSetDevice\|device.*=' csrc/xqa/mla_sm120.cu | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search for the kernel launch function that contains the first occurrence
rg -n 'void.*launchMLAFlashInfer\|void.*launchXQA' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check what function contains the lambda at lines 1728-1730
sed -n '1700,1715p' csrc/xqa/mla_sm120.cu

Repository: flashinfer-ai/flashinfer

Length of output: 967


🏁 Script executed:

# Look for how this file is used in other places
fd -e cu -e h | head -20 | xargs rg -l 'launchMLAFlashInfer\|configureKernel' 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 52


🏁 Script executed:

# Check if there's a pattern for device management in the codebase
rg -n 'cudaSetDevice|getActiveDevice' csrc/

Repository: flashinfer-ai/flashinfer

Length of output: 675


🏁 Script executed:

# Look at the actual function signature containing the lambda
sed -n '1655,1730p' csrc/xqa/mla_sm120.cu | tail -30

Repository: flashinfer-ai/flashinfer

Length of output: 1578


Use the currently active CUDA device instead of hardcoded device 0 for SMEM capability checks.

Lines 1728-1730 and 1831-1833 query cudaDevAttrMaxSharedMemoryPerBlockOptin on device 0. On multi-GPU runs, these static initializers may execute with a different active device than intended, causing validation against the wrong GPU's capabilities and leading to incorrect pass/fail behavior.

The first check runs in a static lambda inside launchMLA() on first invocation, and the second runs at module initialization in configureKernel(), both potentially with undefined device context.

πŸ”§ Proposed fix
-    int devMaxShmem = 0;
-    checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
-                                      cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
+    int devMaxShmem = 0;
+    int device = -1;
+    checkCuda(cudaGetDevice(&device));
+    checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
+                                     cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                                     device));
-  int devMaxShmem = 0;
-  checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
-                                    cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
+  int devMaxShmem = 0;
+  int device = -1;
+  checkCuda(cudaGetDevice(&device));
+  checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
+                                   cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                                   device));
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
int devMaxShmem = 0;
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
int devMaxShmem = 0;
int device = -1;
checkCuda(cudaGetDevice(&device));
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device));
πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/xqa/mla_sm120.cu` around lines 1728 - 1730, The checks for
cudaDevAttrMaxSharedMemoryPerBlockOptin use a hardcoded device 0; change them to
query the currently active device first and use that device id for the attribute
call. Concretely, before calling cudaDeviceGetAttribute in launchMLA() and
configureKernel(), call cudaGetDevice(&dev) (or an equivalent helper), then pass
dev into cudaDeviceGetAttribute(&devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); preserve existing error handling
around checkCuda and ensure the variable names (devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, launchMLA, configureKernel) remain
consistent.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@saltyminty, thanks for confirming! Here's the fix applied to both locations:

In launchMLA() (static lambda, ~line 1728):

int devMaxShmem = 0;
int device = -1;
checkCuda(cudaGetDevice(&device));
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
                                 cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                 device));

In configureKernel() (~line 1831):

int devMaxShmem = 0;
int device = -1;
checkCuda(cudaGetDevice(&device));
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
                                 cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                 device));

Note that configureKernel() runs at static initialization time (static uint32_t const hostSmemSize = configureKernel();), so the active device context there may be particularly unpredictable. You may also want to ensure that hostSmemSize is not cached as a file-scope static if multi-GPU support with different GPU models is needed β€” or document the assumption that all devices in a multi-GPU setup are homogeneous.


🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-03-04T05:20:26.963Z
Learning: Applies to csrc/*.cu : Implement kernel launchers in `csrc/` that handle framework-specific tensor operations

Learnt from: depaulmillz
Repo: flashinfer-ai/flashinfer PR: 2738
File: csrc/group_gemm_nvfp4_groupwise_sm120.cu:101-102
Timestamp: 2026-03-11T16:24:12.197Z
Learning: In CUDA source files under csrc (e.g., csrc/group_gemm_nvfp4_groupwise_sm120.cu and similar), it is valid and intentional that int_workspace_buffer and float_workspace_buffer are allocated on the same device as input tensor a via _get_cache_buf(..., a.device), and that CUDADeviceGuard is sourced from float_workspace_buffer.device() with the stream from A.device(). Do not flag these as device inconsistencies; instead, verify actual inconsistencies elsewhere and rely on this established pattern.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2709
File: include/flashinfer/mamba/seq_chunk_cumsum.cuh:0-0
Timestamp: 2026-03-06T20:52:57.849Z
Learning: In `include/flashinfer/mamba/seq_chunk_cumsum.cuh` and `csrc/seq_chunk_cumsum.cu`, the maintainer explicitly does not want runtime validation of metadata (chunk_indices, chunk_offsets, seq_idx bounds, monotonicity) in the kernel launcher or device code because this is a high-throughput kernel. Do not suggest adding such checks. Debug-mode assertions may be acceptable but should not be pushed.

Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Learnt from: xrq-phys
Repo: flashinfer-ai/flashinfer PR: 2711
File: csrc/trtllm_fmha_kernel_launcher.cu:552-563
Timestamp: 2026-03-07T06:34:53.719Z
Learning: In `csrc/trtllm_fmha_kernel_launcher.cu` (flashinfer-ai/flashinfer), dtype validation for SageAttention scaling-factor tensors (`sage_attn_sfs_q/k/p/v`) is intentionally absent. This file is a TVM FFI path (not a PyTorch extension path), and dtype validation is expected to be handled at a different layer/entry point. Do not flag missing `TVM_FFI_ICHECK_EQ(...dtype(), dl_float32)` checks for these tensors in this file.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2591
File: flashinfer/aot.py:588-599
Timestamp: 2026-02-19T21:59:36.542Z
Learning: When reviewing changes to conditional blocks (e.g., `if has_sm90:` β†’ `if has_sm90 or has_sm100:`), distinguish between code the PR author wrote versus pre-existing code that happens to be in the modified block. Do not ask the PR author to fix potential issues in pre-existing code unless it's directly related to their changes.

Learnt from: danisereb
Repo: flashinfer-ai/flashinfer PR: 2464
File: include/flashinfer/gemm/mxfp8_gemm_template_sm100.h:148-163
Timestamp: 2026-02-04T10:08:47.455Z
Learning: In flashinfer GEMM template implementations (e.g., include/flashinfer/gemm/fp4_gemm_template_sm100.h, mxfp8_gemm_template_sm100.h), the Sm10x11xOnly architecture check wrapper uses a pattern where only thread0() prints an error message and calls __trap() when running on unsupported architectures. This pattern is intentional and working in production code, so consistency should be maintained across similar implementations.

if (size > (uint32_t)devMaxShmem) {
throw std::runtime_error(
"XQA MLA kernel requires " + std::to_string(size) + " bytes shared memory per block, but "
"device opt-in max is " + std::to_string(devMaxShmem) + ". BF16 MLA needs 128 KB (e.g. SM12x).");
Comment on lines +1731 to +1734
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error message is very helpful for debugging. However, consider adding the required shared memory size for FP8 as well, to provide a complete picture to the user.

throw std::runtime_error(
          "XQA MLA kernel requires " + std::to_string(size) + " bytes shared memory per block, but "
          "device opt-in max is " + std::to_string(devMaxShmem) + ". BF16 MLA needs 128 KB (e.g. SM12x), FP8 needs [size] KB.");

}
checkCuda(cudaFuncSetAttribute(kernel_mha,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxShared));
checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size));
return size;
}();
Expand Down Expand Up @@ -1768,8 +1827,19 @@ void launchMLA(

// Reads the kernel's shared-memory requirement from device-constant smemSize,
// validates it against the *active* device's opt-in per-block limit, and
// configures kernel_mha (carveout + max dynamic shared memory) accordingly.
// Returns the dynamic shared-memory size in bytes.
// Throws std::runtime_error if the device cannot provide enough shared memory.
static uint32_t configureKernel() {
  uint32_t size;
  checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize)));
  // Query the currently active device, not hardcoded device 0: in multi-GPU
  // processes this may run with any device current, and validating against
  // the wrong GPU's capability gives incorrect pass/fail behavior.
  int device = -1;
  checkCuda(cudaGetDevice(&device));
  int devMaxShmem = 0;
  checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
                                   cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
  if (size > (uint32_t)devMaxShmem) {
    throw std::runtime_error(
        "XQA MLA kernel requires " + std::to_string(size) + " bytes shared memory per block, but "
        "device opt-in max is " + std::to_string(devMaxShmem) + ". BF16 MLA needs 128 KB (e.g. SM12x).");
  }
  checkCuda(cudaFuncSetAttribute(kernel_mha,
                                 cudaFuncAttributePreferredSharedMemoryCarveout,
                                 cudaSharedmemCarveoutMaxShared));
  checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size));
  return size;
}

Expand Down
24 changes: 16 additions & 8 deletions flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,17 @@ def trtllm_batch_decode_with_kv_cache_mla(
if isinstance(bmm2_scale, torch.Tensor):
assert bmm2_scale.dtype == torch.float32
if backend == "xqa":
if (
get_compute_capability(query.device)[0] != 12
or query.dtype != torch.float8_e4m3fn
or kv_cache.dtype != torch.float8_e4m3fn
):
if get_compute_capability(query.device)[0] != 12:
raise ValueError("XQA MLA is only supported on SM120/SM121 GPUs")
fp8_ok = (
query.dtype == torch.float8_e4m3fn and kv_cache.dtype == torch.float8_e4m3fn
)
bf16_ok = (
query.dtype == torch.bfloat16 and kv_cache.dtype == torch.bfloat16
)
if not (fp8_ok or bf16_ok):
raise ValueError(
f"XQA MLA only supports fp8 operation on SM120/SM121 GPUs, got {query.dtype} and {kv_cache.dtype}"
f"XQA MLA on SM120/SM121 supports (fp8, fp8) or (bfloat16, bfloat16) only, got {query.dtype} and {kv_cache.dtype}"
)
Comment on lines 611 to 613
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error message could be improved by explicitly stating the expected data types for query and kv_cache when BF16 is enabled. This will help users quickly identify the correct data type configuration.

raise ValueError(
                f

if sinks is not None:
raise ValueError("XQA MLA does not support sinks")
Expand Down Expand Up @@ -767,9 +771,13 @@ def xqa_batch_decode_with_kv_cache_mla(
raise ValueError(
f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}"
)
if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
fp8_ok = (
query.dtype == torch.float8_e4m3fn and kv_cache.dtype == torch.float8_e4m3fn
)
bf16_ok = query.dtype == torch.bfloat16 and kv_cache.dtype == torch.bfloat16
if not (fp8_ok or bf16_ok):
raise ValueError(
f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}"
f"XQA MLA supports (fp8, fp8) or (bfloat16, bfloat16) only, got {query.dtype} and {kv_cache.dtype}"
)
if sinks is not None:
raise ValueError("XQA MLA does not support sinks")
Expand Down