-
Notifications
You must be signed in to change notification settings - Fork 906
Support BF16 MLA on SM120 with shared-mem fallback #2675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -41,11 +41,23 @@ __constant__ constexpr XQAKernelType kernelType = XQAKernelType::kSM120_MLA; | |||||||||||||||||||
|
|
||||||||||||||||||||
| inline constexpr bool allowMultipleInputTokens = true; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| inline constexpr uint32_t partElemsK = 64; // @fixme: change this to 128 to save L2 traffic | ||||||||||||||||||||
| using MathElem = CacheElem; | ||||||||||||||||||||
| inline constexpr uint32_t mathElemBytes = sizeof(MathElem); | ||||||||||||||||||||
| inline constexpr bool is_fp8 = (mathElemBytes == 1); | ||||||||||||||||||||
| inline constexpr bool is_bf16 = (mathElemBytes == 2); | ||||||||||||||||||||
| // BF16: partElemsK=64, nbKBufs=2 β ~100KB, under 99KB opt-in (101376). | ||||||||||||||||||||
| inline constexpr uint32_t partElemsK = | ||||||||||||||||||||
| is_fp8 ? 64 : | ||||||||||||||||||||
| is_bf16 ? 64 : | ||||||||||||||||||||
| 64; | ||||||||||||||||||||
| inline constexpr uint32_t nbKParts = exactDiv(validElemsPerKHead, partElemsK); | ||||||||||||||||||||
| inline constexpr uint32_t nbQParts = nbKParts; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| inline constexpr uint32_t tokensPerTile = 64; | ||||||||||||||||||||
| inline constexpr uint32_t tokensPerTile = | ||||||||||||||||||||
| is_fp8 ? 64 : | ||||||||||||||||||||
| is_bf16 ? 32 : | ||||||||||||||||||||
| 64; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| inline constexpr uint32_t partElemsV = 128; | ||||||||||||||||||||
| inline constexpr uint32_t nbVSplit = 2; | ||||||||||||||||||||
| inline constexpr uint32_t gemm1V = exactDiv(validElemsPerVHead, nbVSplit); | ||||||||||||||||||||
|
|
@@ -54,12 +66,12 @@ inline constexpr uint32_t nbProducerCtasPerCga = nbVSplit; | |||||||||||||||||||
| inline constexpr uint32_t multiBlockMinNbTilesPerCta = 2; | ||||||||||||||||||||
| inline constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| using MathElem = CacheElem; | ||||||||||||||||||||
| inline constexpr uint32_t mathElemBytes = sizeof(MathElem); | ||||||||||||||||||||
| inline constexpr uint32_t grainsPerPartK = exactDiv(partElemsK * mathElemBytes, grainBytes); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| inline constexpr uint32_t grainElems = exactDiv(grainBytes, mathElemBytes); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| inline constexpr mmaShape kernelQmmaShape = is_fp8 ? mmaShape{16, 8, 32} : mmaShape{16, 8, 16}; | ||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||
|
|
||||||||||||||||||||
| inline constexpr float xScale = 1.f / kE4M3_MAX; | ||||||||||||||||||||
| __constant__ constexpr float rcpXScale = kE4M3_MAX; | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -162,16 +174,16 @@ class Mat16x32Loader { | |||||||||||||||||||
| __device__ inline Mat16x32Loader(Src const& src, uint32_t baseRow, uint32_t idxInstK, | ||||||||||||||||||||
| uint32_t r = laneId() % 16, uint32_t c = laneId() / 16) | ||||||||||||||||||||
| : src{src}, baseRow{baseRow}, idxInstK{idxInstK}, r{r}, c{c}, basePtr{getPtrRef(0)} { | ||||||||||||||||||||
| static_assert((grainBytes * srcCols * qmmaShape.m) % 1024 == 0); | ||||||||||||||||||||
| static_assert((grainBytes * srcCols * kernelQmmaShape.m) % 1024 == 0); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| __device__ inline Mat16x32 load(uint32_t idxInstM) const { | ||||||||||||||||||||
| return ldmatrix<false, 4>(getPtr(idxInstM)); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| template <uint32_t tileM> | ||||||||||||||||||||
| __device__ inline Vec<Mat16x32, exactDiv(tileM, qmmaShape.m)> loadWholeCol() const { | ||||||||||||||||||||
| uint32_t const nbInstM = exactDiv(tileM, qmmaShape.m); | ||||||||||||||||||||
| __device__ inline Vec<Mat16x32, exactDiv(tileM, kernelQmmaShape.m)> loadWholeCol() const { | ||||||||||||||||||||
| uint32_t const nbInstM = exactDiv(tileM, kernelQmmaShape.m); | ||||||||||||||||||||
| Vec<Mat16x32, nbInstM> ret; | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t i = 0; i < nbInstM; i++) { | ||||||||||||||||||||
|
|
@@ -181,13 +193,13 @@ class Mat16x32Loader { | |||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| __device__ inline LdGrain const* getPtr(uint32_t idxInstM) const { | ||||||||||||||||||||
| return checkedVal(basePtr + idxInstM * qmmaShape.m * srcCols, getPtrRef(idxInstM)); | ||||||||||||||||||||
| return checkedVal(basePtr + idxInstM * kernelQmmaShape.m * srcCols, getPtrRef(idxInstM)); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| private: | ||||||||||||||||||||
| __device__ inline LdGrain const* getPtrRef(uint32_t idxInstM) const { | ||||||||||||||||||||
| return &src.template at<true>(baseRow + idxInstM * qmmaShape.m + r, | ||||||||||||||||||||
| idxInstK * exactDiv(qmmaShape.k, grainElems) + c); | ||||||||||||||||||||
| return &src.template at<true>(baseRow + idxInstM * kernelQmmaShape.m + r, | ||||||||||||||||||||
| idxInstK * exactDiv(kernelQmmaShape.k, grainElems) + c); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Src const& src; | ||||||||||||||||||||
|
|
@@ -263,7 +275,9 @@ constexpr uint32_t multiBlockMathWarps = 8; | |||||||||||||||||||
| constexpr bool useRegQ = USE_REG_Q; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| struct SharedMemA { | ||||||||||||||||||||
| static inline constexpr uint32_t nbKBufs = 12; | ||||||||||||||||||||
| // BF16: 2 K-buffers to fit β€99KB opt-in (~100096 bytes); 3 buffers would need ~104KB (128KB arch). | ||||||||||||||||||||
| static inline constexpr uint32_t nbKBufs = | ||||||||||||||||||||
| is_fp8 ? 12 : (is_bf16 ? 2 : 12); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| static inline constexpr uint32_t regQParts = (useRegQ ? 4 : 0); | ||||||||||||||||||||
| static inline constexpr uint32_t shmQParts = nbQParts - regQParts; | ||||||||||||||||||||
|
|
@@ -587,12 +601,12 @@ struct Producer { | |||||||||||||||||||
| uint32_t const tileBaseRow = warpTile.y * warpIdx.x; | ||||||||||||||||||||
| PingPongMutex tensorCoreMutex{smem.tensorCoreMutex, grpIdx}; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| constexpr uint32_t partNbInstK = exactDiv(partElemsK, qmmaShape.k); | ||||||||||||||||||||
| constexpr uint32_t partNbInstK = exactDiv(partElemsK, kernelQmmaShape.k); | ||||||||||||||||||||
| using AtomA = Vec<uint32_t, 4>; // for 16x32 data, working as mat A of QMMA.16832 | ||||||||||||||||||||
| using RegQPartCol = Vec<AtomA, exactDiv(warpTile.y, qmmaShape.m)>; | ||||||||||||||||||||
| using RegQPartCol = Vec<AtomA, exactDiv(warpTile.y, kernelQmmaShape.m)>; | ||||||||||||||||||||
| using RegQPart = Vec<RegQPartCol, partNbInstK>; | ||||||||||||||||||||
| using RegQ = Vec<RegQPart, SharedMemA::regQParts>; | ||||||||||||||||||||
| constexpr uint32_t tileNbAtomBx2 = exactDiv(tokensPerTile, qmmaShape.n * 2); | ||||||||||||||||||||
| constexpr uint32_t tileNbAtomBx2 = exactDiv(tokensPerTile, kernelQmmaShape.n * 2); | ||||||||||||||||||||
| using AtomBx2 = Vec<uint32_t, 4>; // one AtomB is 8x32 and AtomBx2 is 16x32 | ||||||||||||||||||||
| using RegKPartCol = Vec<AtomBx2, tileNbAtomBx2>; | ||||||||||||||||||||
| using RegKPart = Vec<RegKPartCol, partNbInstK>; | ||||||||||||||||||||
|
|
@@ -656,7 +670,8 @@ struct Producer { | |||||||||||||||||||
| RegKPart regKBuf; | ||||||||||||||||||||
| regKBuf[0] = loadRegKCol(smem.k[kBarWaiter.idxBuf], 0); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| auto shouldTestWait = [](uint32_t idxInstK, uint32_t idxAtomBx2) { | ||||||||||||||||||||
| auto shouldTestWait = [partNbInstK, tileNbAtomBx2](uint32_t idxInstK, | ||||||||||||||||||||
| uint32_t idxAtomBx2) { | ||||||||||||||||||||
| return idxInstK == partNbInstK - 1 && idxAtomBx2 == tileNbAtomBx2 - 2; | ||||||||||||||||||||
| }; | ||||||||||||||||||||
| BarWaiter kBarWaiterNext = kBarWaiter.next(); | ||||||||||||||||||||
|
|
@@ -698,7 +713,7 @@ struct Producer { | |||||||||||||||||||
| for (uint32_t i = 0; i < WarpAcc::rows; i++) { | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t j = 0; j < 2; j++) { | ||||||||||||||||||||
| mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)), | ||||||||||||||||||||
| mma<MathElem>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)), | ||||||||||||||||||||
| reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]), | ||||||||||||||||||||
| reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j])); | ||||||||||||||||||||
|
Comment on lines
+716
to
718
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
@@ -749,7 +764,7 @@ struct Producer { | |||||||||||||||||||
| for (uint32_t i = 0; i < WarpAcc::rows; i++) { | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t j = 0; j < 2; j++) { | ||||||||||||||||||||
| mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)), | ||||||||||||||||||||
| mma<MathElem>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)), | ||||||||||||||||||||
| reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]), | ||||||||||||||||||||
| reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j])); | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
@@ -776,14 +791,14 @@ struct Producer { | |||||||||||||||||||
|
|
||||||||||||||||||||
| auto& xBar = smem.xBars[grpIdx]; | ||||||||||||||||||||
| bool const skipXBarWait = xBar.consumed.test_wait_parity(toParity<1>(grpIter)); | ||||||||||||||||||||
| // convert to fp8 | ||||||||||||||||||||
| ThrdRegRowMax rowSum; | ||||||||||||||||||||
| if constexpr (is_fp8) { | ||||||||||||||||||||
| WarpAcc const xF32Quant = xF32 * rcpXScale; | ||||||||||||||||||||
| // 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15 | ||||||||||||||||||||
| Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> xF8; | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t i = 0; i < WarpAcc::rows; i++) { | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t m = 0; m < exactDiv(qmmaShape.m, 8); m++) { | ||||||||||||||||||||
| for (uint32_t m = 0; m < exactDiv(kernelQmmaShape.m, 8); m++) { | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t j = 0; j < WarpAcc::cols; j += 2) { | ||||||||||||||||||||
| auto& dst = reinterpret_cast<__nv_fp8x2_e4m3(&)[2]>(xF8(i, j / 2)(m, 0)); | ||||||||||||||||||||
|
|
@@ -792,18 +807,24 @@ struct Producer { | |||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| // use tensor core to compute rowSum | ||||||||||||||||||||
| ThrdRegRowMax const rowSum = | ||||||||||||||||||||
| rowSum = | ||||||||||||||||||||
| computeRowSumFromF8 ? computeRowSumF8<warpTile.y, warpTile.x>(this_warp(), xF8) | ||||||||||||||||||||
| : computeRowSumF32<warpTile.y, warpTile.x>(this_warp(), xF32); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| // store xF8 and rowSum into L2 scratch buffer | ||||||||||||||||||||
| if (!skipXBarWait) { | ||||||||||||||||||||
| xBar.consumed.wait_parity(toParity<1>(grpIter)); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| storeRowMax<warpTile.y>(smem.x.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane); | ||||||||||||||||||||
| storeRowMax<warpTile.y>(smem.x.rowSum, rowSum, tileBaseRow, lane); | ||||||||||||||||||||
| storeOrderedXToShm(smem.x.x, xF8, tileBaseRow, lane); | ||||||||||||||||||||
| } else { | ||||||||||||||||||||
| rowSum = computeRowSumF32<warpTile.y, warpTile.x>(this_warp(), xF32); | ||||||||||||||||||||
| if (!skipXBarWait) { | ||||||||||||||||||||
| xBar.consumed.wait_parity(toParity<1>(grpIter)); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| storeRowMax<warpTile.y>(smem.x.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane); | ||||||||||||||||||||
| storeRowMax<warpTile.y>(smem.x.rowSum, rowSum, tileBaseRow, lane); | ||||||||||||||||||||
| storeOrderedXToShmBf16(smem.x.x, xF32, tileBaseRow, lane); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| xBar.produced.arrive(); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
@@ -816,6 +837,9 @@ struct Producer { | |||||||||||||||||||
| XBuffer& dst, | ||||||||||||||||||||
| Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src, | ||||||||||||||||||||
| uint32_t const tileBaseRow, uint32_t const lane = laneId()); | ||||||||||||||||||||
| __device__ inline void storeOrderedXToShmBf16(XBuffer& dst, WarpAcc const& src, | ||||||||||||||||||||
| uint32_t const tileBaseRow, | ||||||||||||||||||||
| uint32_t const lane = laneId()); | ||||||||||||||||||||
| }; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| __device__ inline void Producer::loadK() { | ||||||||||||||||||||
|
|
@@ -966,6 +990,29 @@ __device__ inline void Producer::storeOrderedXToShm( | |||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| __device__ inline void Producer::storeOrderedXToShmBf16(XBuffer& dst, WarpAcc const& src, | ||||||||||||||||||||
| uint32_t const tileBaseRow, | ||||||||||||||||||||
| uint32_t const lane) { | ||||||||||||||||||||
| constexpr uint32_t grainsPerRow = exactDiv(warpTile.x * sizeof(__nv_bfloat16), grainBytes); | ||||||||||||||||||||
| constexpr uint32_t totalGrains = warpTile.y * grainsPerRow; | ||||||||||||||||||||
| constexpr uint32_t grainsPerThread = exactDiv(totalGrains, 32); | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t i = 0; i < grainsPerThread; i++) { | ||||||||||||||||||||
| uint32_t const idx = lane + i * 32; | ||||||||||||||||||||
| uint32_t const row = idx / grainsPerRow; | ||||||||||||||||||||
| uint32_t const g = idx % grainsPerRow; | ||||||||||||||||||||
| if (row < warpTile.y) { | ||||||||||||||||||||
| __nv_bfloat16* p = | ||||||||||||||||||||
| reinterpret_cast<__nv_bfloat16*>(&dst.template at<true>(tileBaseRow + row, g)); | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t j = 0; j < 8; j++) { | ||||||||||||||||||||
| uint32_t const col = g * 8 + j; | ||||||||||||||||||||
| p[j] = __float2bfloat16(src(row / 2, col / 2)(row % 2, col % 2)); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| struct Consumer { | ||||||||||||||||||||
| static inline constexpr uint32_t nbMathWarps = nbMathWarpsB; | ||||||||||||||||||||
| static inline constexpr uint32_t nbMathThrds = warp_size * nbMathWarps; | ||||||||||||||||||||
|
|
@@ -1115,8 +1162,8 @@ __device__ inline void Consumer::compute() { | |||||||||||||||||||
| uint2 const tileIdx = {warpIdx.y, warpIdx.x}; | ||||||||||||||||||||
| uint2 const tileBase = {tileIdx.x * warpTile.x, tileIdx.y * warpTile.y}; | ||||||||||||||||||||
|
|
||||||||||||||||||||
| constexpr uint32_t tileNbInstK = exactDiv(tokensPerTile, qmmaShape.k); | ||||||||||||||||||||
| constexpr uint32_t warpTileNbAtomBx2 = exactDiv(warpTile.x, qmmaShape.n * 2); | ||||||||||||||||||||
| constexpr uint32_t tileNbInstK = exactDiv(tokensPerTile, kernelQmmaShape.k); | ||||||||||||||||||||
| constexpr uint32_t warpTileNbAtomBx2 = exactDiv(warpTile.x, kernelQmmaShape.n * 2); | ||||||||||||||||||||
|
|
||||||||||||||||||||
| uint32_t const lane = laneId(); | ||||||||||||||||||||
| uint32_t const idxHalf = lane / 16; | ||||||||||||||||||||
|
|
@@ -1195,19 +1242,19 @@ __device__ inline void Consumer::compute() { | |||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t idxInstK = 0; idxInstK < tileNbInstK; idxInstK++) { | ||||||||||||||||||||
| Mat16x32Loader const loaderX(xBuf, tileBase.y, idxInstK, rA, cA); | ||||||||||||||||||||
| Vec<Mat16x32, exactDiv(warpTile.y, qmmaShape.m)> const x = loaderX.loadWholeCol<warpTile.y>(); | ||||||||||||||||||||
| Vec<Mat16x32, exactDiv(warpTile.y, kernelQmmaShape.m)> const x = loaderX.loadWholeCol<warpTile.y>(); | ||||||||||||||||||||
| using AtomB = Vec<uint32_t, 2>; | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < warpTileNbAtomBx2; idxAtomBx2++) { | ||||||||||||||||||||
| auto const data = ldmatrix_16x16_trans<2>( | ||||||||||||||||||||
| &vBuf.template at<true>(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB)); | ||||||||||||||||||||
| &vBuf.template at<true>(kernelQmmaShape.k * idxInstK + rB, idxAtomBx2 + cB)); | ||||||||||||||||||||
| AtomB const v[2] = {data[0], data[2], data[1], data[3]}; | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t i = 0; i < WarpAcc::rows; i++) { | ||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||
| for (uint32_t j = 0; j < 2; j++) { | ||||||||||||||||||||
| #if 1 | ||||||||||||||||||||
| mma<__nv_fp8_e4m3>( | ||||||||||||||||||||
| mma<MathElem>( | ||||||||||||||||||||
| #else | ||||||||||||||||||||
| mmaF8_k32_2inst( | ||||||||||||||||||||
| #endif | ||||||||||||||||||||
|
|
@@ -1630,7 +1677,9 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha | |||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| __constant__ constexpr uint32_t smemSize = mha::max(sizeof(SharedMemA), sizeof(SharedMemB)); | ||||||||||||||||||||
| static_assert(smemSize <= 99 * 1024, "Shared memory size exceeded"); | ||||||||||||||||||||
| // BF16 with nbKBufs=2 uses ~100KB; allow up to 99KB opt-in (101376) for devices that support it. | ||||||||||||||||||||
| static constexpr uint32_t kSmemLimitBytes = is_bf16 ? 101376 : 99 * 1024; | ||||||||||||||||||||
| static_assert(smemSize <= kSmemLimitBytes, "Shared memory size exceeded"); | ||||||||||||||||||||
|
Comment on lines
+1680
to
+1682
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment indicates that BF16 uses ~100KB, but the limit is set to 99KB. This discrepancy should be resolved. Either the comment should be updated to reflect the actual usage, or the |
||||||||||||||||||||
| #endif // is_MLA | ||||||||||||||||||||
|
|
||||||||||||||||||||
| #ifndef GENERATE_CUBIN | ||||||||||||||||||||
|
|
@@ -1674,9 +1723,19 @@ void launchMLA( | |||||||||||||||||||
| throw std::runtime_error("not implemented"); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| static uint32_t const hostSmemSize = [&]() { | ||||||||||||||||||||
| // printf("smemSize = %u\n", smemSize); | ||||||||||||||||||||
| uint32_t size; | ||||||||||||||||||||
| checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); | ||||||||||||||||||||
| int devMaxShmem = 0; | ||||||||||||||||||||
| checkCuda(cudaDeviceGetAttribute(&devMaxShmem, | ||||||||||||||||||||
| cudaDevAttrMaxSharedMemoryPerBlockOptin, 0)); | ||||||||||||||||||||
|
Comment on lines
+1728
to
+1730
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π§© Analysis chainπ Script executed: # First, let's check the file size and then examine the specific lines
wc -l csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 94 π Script executed: # Examine the code around lines 1728-1730
sed -n '1720,1740p' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 1167 π Script executed: # Examine the code around lines 1831-1833
sed -n '1823,1843p' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 962 π Script executed: # Run the verification command from the review comment
rg -n -C2 'cudaDeviceGetAttribute\(&devMaxShmem,\s*cudaDevAttrMaxSharedMemoryPerBlockOptin,\s*0\)' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 50 π Script executed: # Also check for the pattern more broadly to understand context
rg -n 'devMaxShmem' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 547 π Script executed: # Get more context around the first occurrence (lambda function)
sed -n '1710,1750p' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 2243 π Script executed: # Get function signature and more context for configureKernel
sed -n '1825,1855p' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 1539 π Script executed: # Search for where configureKernel and the lambda are called
rg -n 'configureKernel\|hostSmemSize' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 50 π Script executed: # Check if cudaGetDevice is already used elsewhere in the file
rg -n 'cudaGetDevice' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 50 π Script executed: # Look for cudaSetDevice to understand multi-GPU context
rg -n 'cudaSetDevice\|device.*=' csrc/xqa/mla_sm120.cu | head -20Repository: flashinfer-ai/flashinfer Length of output: 50 π Script executed: # Search for the kernel launch function that contains the first occurrence
rg -n 'void.*launchMLAFlashInfer\|void.*launchXQA' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 50 π Script executed: # Check what function contains the lambda at lines 1728-1730
sed -n '1700,1715p' csrc/xqa/mla_sm120.cuRepository: flashinfer-ai/flashinfer Length of output: 967 π Script executed: # Look for how this file is used in other places
fd -e cu -e h | head -20 | xargs rg -l 'launchMLAFlashInfer\|configureKernel' 2>/dev/nullRepository: flashinfer-ai/flashinfer Length of output: 52 π Script executed: # Check if there's a pattern for device management in the codebase
rg -n 'cudaSetDevice|getActiveDevice' csrc/Repository: flashinfer-ai/flashinfer Length of output: 675 π Script executed: # Look at the actual function signature containing the lambda
sed -n '1655,1730p' csrc/xqa/mla_sm120.cu | tail -30Repository: flashinfer-ai/flashinfer Length of output: 1578 Use the currently active CUDA device instead of hardcoded device Lines 1728-1730 and 1831-1833 query The first check runs in a static lambda inside π§ Proposed fix- int devMaxShmem = 0;
- checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
- cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
+ int devMaxShmem = 0;
+ int device = -1;
+ checkCuda(cudaGetDevice(&device));
+ checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
+ cudaDevAttrMaxSharedMemoryPerBlockOptin,
+ device));- int devMaxShmem = 0;
- checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
- cudaDevAttrMaxSharedMemoryPerBlockOptin, 0));
+ int devMaxShmem = 0;
+ int device = -1;
+ checkCuda(cudaGetDevice(&device));
+ checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
+ cudaDevAttrMaxSharedMemoryPerBlockOptin,
+ device));π Committable suggestion
Suggested change
π€ Prompt for AI Agents
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In int devMaxShmem = 0;
int device = -1;
checkCuda(cudaGetDevice(&device));
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device));In int devMaxShmem = 0;
int device = -1;
checkCuda(cudaGetDevice(&device));
checkCuda(cudaDeviceGetAttribute(&devMaxShmem,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device));Note that π§ Learnings used |
||||||||||||||||||||
| if (size > (uint32_t)devMaxShmem) { | ||||||||||||||||||||
| throw std::runtime_error( | ||||||||||||||||||||
| "XQA MLA kernel requires " + std::to_string(size) + " bytes shared memory per block, but " | ||||||||||||||||||||
| "device opt-in max is " + std::to_string(devMaxShmem) + ". BF16 MLA needs 128 KB (e.g. SM12x)."); | ||||||||||||||||||||
|
Comment on lines
+1731
to
+1734
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The error message is very helpful for debugging. However, consider adding the required shared memory size for FP8 as well, to provide a complete picture to the user. |
||||||||||||||||||||
| } | ||||||||||||||||||||
| checkCuda(cudaFuncSetAttribute(kernel_mha, | ||||||||||||||||||||
| cudaFuncAttributePreferredSharedMemoryCarveout, | ||||||||||||||||||||
| cudaSharedmemCarveoutMaxShared)); | ||||||||||||||||||||
| checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); | ||||||||||||||||||||
| return size; | ||||||||||||||||||||
| }(); | ||||||||||||||||||||
|
|
@@ -1768,8 +1827,19 @@ void launchMLA( | |||||||||||||||||||
|
|
||||||||||||||||||||
| static uint32_t configureKernel() { | ||||||||||||||||||||
| uint32_t size; | ||||||||||||||||||||
| cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize)); | ||||||||||||||||||||
| cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size); | ||||||||||||||||||||
| checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); | ||||||||||||||||||||
| int devMaxShmem = 0; | ||||||||||||||||||||
| checkCuda(cudaDeviceGetAttribute(&devMaxShmem, | ||||||||||||||||||||
| cudaDevAttrMaxSharedMemoryPerBlockOptin, 0)); | ||||||||||||||||||||
| if (size > (uint32_t)devMaxShmem) { | ||||||||||||||||||||
| throw std::runtime_error( | ||||||||||||||||||||
| "XQA MLA kernel requires " + std::to_string(size) + " bytes shared memory per block, but " | ||||||||||||||||||||
| "device opt-in max is " + std::to_string(devMaxShmem) + ". BF16 MLA needs 128 KB (e.g. SM12x)."); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| checkCuda(cudaFuncSetAttribute(kernel_mha, | ||||||||||||||||||||
| cudaFuncAttributePreferredSharedMemoryCarveout, | ||||||||||||||||||||
| cudaSharedmemCarveoutMaxShared)); | ||||||||||||||||||||
| checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); | ||||||||||||||||||||
| return size; | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -599,13 +599,17 @@ def trtllm_batch_decode_with_kv_cache_mla( | |
| if isinstance(bmm2_scale, torch.Tensor): | ||
| assert bmm2_scale.dtype == torch.float32 | ||
| if backend == "xqa": | ||
| if ( | ||
| get_compute_capability(query.device)[0] != 12 | ||
| or query.dtype != torch.float8_e4m3fn | ||
| or kv_cache.dtype != torch.float8_e4m3fn | ||
| ): | ||
| if get_compute_capability(query.device)[0] != 12: | ||
| raise ValueError("XQA MLA is only supported on SM120/SM121 GPUs") | ||
| fp8_ok = ( | ||
| query.dtype == torch.float8_e4m3fn and kv_cache.dtype == torch.float8_e4m3fn | ||
| ) | ||
| bf16_ok = ( | ||
| query.dtype == torch.bfloat16 and kv_cache.dtype == torch.bfloat16 | ||
| ) | ||
| if not (fp8_ok or bf16_ok): | ||
| raise ValueError( | ||
| f"XQA MLA only supports fp8 operation on SM120/SM121 GPUs, got {query.dtype} and {kv_cache.dtype}" | ||
| f"XQA MLA on SM120/SM121 supports (fp8, fp8) or (bfloat16, bfloat16) only, got {query.dtype} and {kv_cache.dtype}" | ||
| ) | ||
|
Comment on lines
611
to
613
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| if sinks is not None: | ||
| raise ValueError("XQA MLA does not support sinks") | ||
|
|
@@ -767,9 +771,13 @@ def xqa_batch_decode_with_kv_cache_mla( | |
| raise ValueError( | ||
| f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}" | ||
| ) | ||
| if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn: | ||
| fp8_ok = ( | ||
| query.dtype == torch.float8_e4m3fn and kv_cache.dtype == torch.float8_e4m3fn | ||
| ) | ||
| bf16_ok = query.dtype == torch.bfloat16 and kv_cache.dtype == torch.bfloat16 | ||
| if not (fp8_ok or bf16_ok): | ||
| raise ValueError( | ||
| f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}" | ||
| f"XQA MLA supports (fp8, fp8) or (bfloat16, bfloat16) only, got {query.dtype} and {kv_cache.dtype}" | ||
| ) | ||
| if sinks is not None: | ||
| raise ValueError("XQA MLA does not support sinks") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The conditional assignment for
partElemsKbased onis_fp8andis_bf16is a good approach to optimize shared memory usage based on the data type. However, consider adding a static assertion to ensure that the chosen value ofpartElemsKandnbKBufs(defined later) results in a shared memory footprint within the 99KB limit. This will provide a compile-time check against exceeding the limit.