feat: update trtllm-gen MoE cubins #2416
```diff
@@ -96,6 +96,9 @@ struct BatchedGemmData {
   struct InputBuffers {
     // The matrix A. The data type is controlled by options.mDtypeA.
     //
+    // Sparsity is only supported with batchN.
+    // Let S be the sparsity ratio (1 for dense, 2 for sparse).
+    //
     // If (routeAct == true && batchM), the shape is [M, K]
     // Elseif (batchStrideInTokens > 0)
     //   If batchM:
```
```diff
@@ -104,23 +107,23 @@ struct BatchedGemmData {
     //
     //   If batchN:
     //     If layoutA is MatrixLayout::MajorK
-    //       Logical shape is [B, divUpMul(M, tileM), K].
-    //       Logical strides are [divUpMul(M, tileM) * K, K, 1].
-    //     If layoutA is MatrixLayout::MajorMn
+    //       Logical shape is [B, divUpMul(M, tileM), K / S].
+    //       Logical strides are [divUpMul(M, tileM) * K / S, K / S, 1].
+    //     If layoutA is MatrixLayout::MajorMn (sparsity not supported)
     //       Logical shape is [B, K, divUpMul(M, tileM)].
     //       Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM), 1].
     //     If layoutA is MatrixLayout::BlockMajorK
-    //       Logical shape is [B, K / blockK, divUpMul(M, tileM), blockK].
-    //       Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM) * blockK, blockK, 1].
-    //       where blockK is 128B.
+    //       Logical shape is [B, K / S / blockK, divUpMul(M, tileM), blockK].
+    //       Logical strides are [K / S * divUpMul(M, tileM), divUpMul(M, tileM) * blockK, blockK,
+    //       1]. where blockK is 128B.
     // Else // batchStrideInTokens == 0
-    //   If batchM:
+    //   If batchM: (sparsity not supported)
     //     Logical shape is [M, K].
     //     Logical strides are [K, 1].
     //
     //   If batchN:
-    //     Logical shape is [B, divUpMul(M, tileM), K].
-    //     Logical strides are [divUpMul(M, tileM) * K, K, 1].
+    //     Logical shape is [B, divUpMul(M, tileM), K / S].
+    //     Logical strides are [divUpMul(M, tileM) * K / S, K / S, 1].
     void const* mPtrA{nullptr};
```
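To make the S bookkeeping concrete, here is a minimal sketch (not from the PR; `shapeOfAMajorK` and its arguments are illustrative) of the logical shape and strides of A for the batchN, `MatrixLayout::MajorK` case:

```cpp
#include <array>
#include <cstdint>

// Rounds x up to the nearest multiple of m (mirrors divUpMul in the diff).
constexpr int64_t divUpMul(int64_t x, int64_t m) { return ((x + m - 1) / m) * m; }

struct ShapeStrides {
  std::array<int64_t, 3> shape;
  std::array<int64_t, 3> strides;
};

// Illustrative: batchN + MatrixLayout::MajorK. S is the sparsity ratio
// (1 for dense, 2 for sparse), so a sparse A stores only K / S values per row.
ShapeStrides shapeOfAMajorK(int64_t B, int64_t M, int64_t K, int64_t tileM, int64_t S) {
  int64_t const paddedM = divUpMul(M, tileM);
  ShapeStrides r;
  r.shape = {B, paddedM, K / S};
  r.strides = {paddedM * K / S, K / S, 1};
  return r;
}
```

For example, `shapeOfAMajorK(8, 100, 4096, 128, /*S=*/2)` gives shape [8, 128, 2048] and strides [262144, 2048, 1].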
    // The block scaling factors to dequantize A.

```diff
@@ -136,9 +139,9 @@ struct BatchedGemmData {
     // The layout of scaling factors for A is always R128c4
     // M must be a multiple of 128.
     // K must be a multiple of 64.
-    // The "logical" shape is: [paddedM, K / 16].
-    // The R128c4 layout is: [paddedM / 128, K / 16 / 4, 512].
-    // The shape we use for TMA is: [paddedM / 128, K / 16 / 4, 2, 256].
+    // The "logical" shape is: [paddedM, K / P], where P is the scaling block size.
+    // The R128c4 layout is: [paddedM / 128, K / P / 4, 512].
+    // The shape we use for TMA is: [paddedM / 128, K / P / 4, 2, 256].
     // Where paddedM is M if (routeAct == true && batchM), or
     //       sum(divUpMul(M[bi], tileM) for bi in B) if batchM,
     //       otherwise divUpMul(M, tileM) * B.
```
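As a sanity check on the new P parameterization, a hedged sketch (helper name mine; P is commonly 16 for NvFp4 and 32 for MxFp formats, an assumption not stated in this hunk):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Illustrative: the R128c4 layout of A's block scaling factors. The logical
// [paddedM, K / P] grid is regrouped into tiles of 128 rows x 4 columns,
// i.e. 128 * 4 = 512 scaling factors per tile.
std::array<int64_t, 3> r128c4LayoutA(int64_t paddedM, int64_t K, int64_t P) {
  assert(paddedM % 128 == 0);  // Per the layout rules above.
  assert(K % 64 == 0);         // K must be a multiple of 64.
  return {paddedM / 128, K / P / 4, 512};
}
```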
```diff
@@ -209,15 +212,15 @@ struct BatchedGemmData {
     // If the layout is R128c4,
     // paddedN must be a multiple of 128.
     // K must be a multiple of 64.
-    // The R128c4 layout is: [paddedN / 128, K / 16 / 4, 512]
-    // The shape we use for TMA is: [paddedN / 128, K / 16 / 4, 2, 256]
+    // The R128c4 layout is: [paddedN / 128, K / P / 4, 512], where P is the scaling block
+    // size. The shape we use for TMA is: [paddedN / 128, K / P / 4, 2, 256]
     //
     // If the layout is R8c4,
     // paddedN must be a multiple of 8.
     // K must be a multiple of 64.
-    // The R8c4 layout is: [paddedN / 8, K / 16 / 4, 32]
-    // The shape we use for TMA is: [paddedN / 8, K / 16 / 4 / repeats, repeats * 32]
-    // where repeats = min(tileK / 16 / 4, 8)
+    // The R8c4 layout is: [paddedN / 8, K / P / 4, 32], where P is the scaling block size.
+    // The shape we use for TMA is: [paddedN / 8, K / P / 4 / repeats, repeats * 32]
+    // where repeats = min(tileK / P / 4, 8)
     //
     // where paddedN is N if (routeAct == true && batchN),
     // or sum(divUpMul(N[bi], tileN) for bi in B) if batchN,
```
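The repeats folding for R8c4 can be spelled out the same way; a minimal sketch under the same assumptions (function name mine):

```cpp
#include <algorithm>
#include <array>
#include <cstdint>

// Illustrative: the TMA box for B's scaling factors in the R8c4 layout.
// Up to 8 column groups are folded into the innermost TMA dimension;
// the total element count matches [paddedN / 8, K / P / 4, 32].
std::array<int64_t, 3> r8c4TmaShapeB(int64_t paddedN, int64_t K, int64_t tileK, int64_t P) {
  int64_t const repeats = std::min<int64_t>(tileK / P / 4, 8);
  return {paddedN / 8, K / P / 4 / repeats, repeats * 32};
}
```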
```diff
@@ -243,6 +246,25 @@ struct BatchedGemmData {
     // Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B)]
     void const* mPtrPerTokenSfB{nullptr};

+    // The sparsity information of A, if structured sparsity is used.
+    // Only supported for batchN (A is weights).
+    //
+    // When sparsityA is Any_2_4:
+    //   2 elements are non-zero in any chunk of 4 elements.
+    //   A 4-bit index indicates the position of the non-zero elements.
+    //   The shape in Uint8 is: [B, divUpMul(M, tileM), K / 8]
+    //   (two 4-bit indices packed into one UInt8)
+    //
+    // When sparsityA is Pairwise_4_8:
+    //   4 elements are non-zero in any chunk of 8 elements.
+    //   The zero and non-zero elements are grouped in pairs.
+    //   A 4-bit index indicates the position of the non-zero pairs.
+    //   The shape in Uint8 is: [B, divUpMul(M, tileM), K / 16]
+    //   (two 4-bit indices packed into one UInt8)
+    //
+    // If sparsityA is Dense, this should be set to nullptr.
+    void const* mPtrSparsityInfoA{nullptr};
+
     // The bias applied after the GEMM and before the activation function.
     // The bias is applied before applying the global scaling factor. I.e.
     // C = act(A * B + bias') * scaleC
```
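A hedged sketch of the metadata packing described above (the nibble order is my assumption; the PR only states that two 4-bit indices share a UInt8):

```cpp
#include <cstdint>
#include <vector>

// Illustrative: pack Any_2_4 metadata. Each chunk of 4 input elements keeps
// 2 non-zeros, described by one 4-bit index, so a row of K elements needs
// K / 4 nibbles = K / 8 bytes. Low-nibble-first order is assumed here.
std::vector<uint8_t> packSparsityNibbles(std::vector<uint8_t> const& nibbles) {
  std::vector<uint8_t> packed(nibbles.size() / 2);
  for (size_t i = 0; i < packed.size(); ++i) {
    packed[i] = static_cast<uint8_t>((nibbles[2 * i] & 0xF) | ((nibbles[2 * i + 1] & 0xF) << 4));
  }
  return packed;
}
```

Pairwise_4_8 halves the index count again (one nibble per 8 elements, hence K / 16 bytes per row), and dense runs leave mPtrSparsityInfoA as nullptr.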
```diff
@@ -381,7 +403,8 @@ struct BatchedGemmData {
     // Computed as
     // int32_t totalNumPaddedTokens{0};
     // for (int bi = 0; bi < options.mNumBatches; bi++) {
-    //   totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM)
+    //   totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM *
+    //                                             options.mClusterDimX)
     //                                  : divUpMul(options.mBatchedN[bi], options.mTileN);
     // }
     // The size is 1 and the dtype is int32_t.
```

Contributor comment on lines +406 to +407, quoting the modified line: `// totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM * options.mClusterDimX)`
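A small worked sketch of the cluster-aware rounding (values illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors divUpMul from the diff: round x up to a multiple of m.
constexpr int32_t divUpMul(int32_t x, int32_t m) { return ((x + m - 1) / m) * m; }

int main() {
  int32_t const tileM = 64, clusterDimX = 2;
  // Without clusters, 60 tokens pad to one tile; with 2-CTA clusters in X,
  // padding is done per cluster, i.e. to a multiple of tileM * clusterDimX.
  std::printf("%d vs %d\n", divUpMul(60, tileM),
              divUpMul(60, tileM * clusterDimX));  // prints "64 vs 128"
  return 0;
}
```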
```diff
@@ -430,26 +453,26 @@ struct BatchedGemmData {
     // The output block scaling factors for C.
     //
     // If MxFp{4,8} and NvFp4 formats are used,
-    // The "logical" shape is:
-    //   if batchM: [paddedM, N / 16]
-    //   if batchN: [paddedN, M / 16]
+    // The "logical" shape is (P is the scaling block size):
+    //   if batchM: [paddedM, N / P]
+    //   if batchN: [paddedN, M / P]
     // where paddedM is sum(divUpMul(M[bi], tileM) for bi in B),
     // where paddedN is sum(divUpMul(N[bi], tileN) for bi in B).
     //
     // If the layout is R128c4,
     // paddedOuter must be a multiple of 128.
     // inner must be a multiple of 64.
-    // The R128c4 layout is: [paddedOuter / 128, inner / 16 / 4, 512]
-    // The shape we use for TMA is: [paddedOuter / 128, inner / 16 / 4, 2, 256]
+    // The R128c4 layout is: [paddedOuter / 128, inner / P / 4, 512]
+    // The shape we use for TMA is: [paddedOuter / 128, inner / P / 4, 2, 256]
     // where inner = N if batchM, otherwise M.
     // where paddedOuter = paddedM if batchM, otherwise paddedN.
     //
     // If the layout is R8c4,
     // paddedOuter must be a multiple of 8.
     // inner must be a multiple of 64.
-    // The R8c4 layout is: [paddedOuter / 8, inner / 16 / 4, 32]
-    // The shape we use for TMA is: [paddedOuter / 8, inner / 16 / 4 / repeats, repeats * 32]
-    // where repeats = min(tileInner / 16 / 4, 8),
+    // The R8c4 layout is: [paddedOuter / 8, inner / P / 4, 32]
+    // The shape we use for TMA is: [paddedOuter / 8, inner / P / 4 / repeats, repeats * 32]
+    // where repeats = min(tileInner / P / 4, 8),
     // where tileInner = tileN if batchM, otherwise tileM,
     // where paddedOuter = paddedM if batchM, otherwise paddedN.
     // where inner = N if batchM, otherwise M.
```
```diff
@@ -528,11 +551,13 @@ class BatchedGemmInterface {
     auto [numCtaBatch, numCtaTile, numCtaInner] =
         getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim);

     auto kernelParams = KernelParamsSetup::setKernelParams(
         options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB,
         batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA,
         batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA,
-        batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias,
+        batchedGemmData.mInputBuffers.mPtrPerTokenSfB,
+        batchedGemmData.mInputBuffers.mPtrSparsityInfoA, batchedGemmData.mInputBuffers.mPtrBias,
         batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC,
         batchedGemmData.mInputBuffers.mPtrScaleAct, batchedGemmData.mInputBuffers.mPtrScaleGate,
         batchedGemmData.mInputBuffers.mPtrClampLimit,
```
```diff
@@ -544,8 +569,7 @@ class BatchedGemmInterface {
         batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch);

     // The size of the grid.
-    std::vector<int32_t> grid = batchM ? std::vector<int32_t>{numCtaBatch, numCtaTile, numCtaInner}
-                                       : std::vector<int32_t>{numCtaTile, numCtaBatch, numCtaInner};
+    auto grid = getLaunchGrid(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim);

     BatchedGemmConfig batchedGemmConfig = config;
 #ifndef TLLM_GEN_EXPORT_INTERFACE
```
```diff
@@ -676,8 +700,10 @@ class BatchedGemmInterface {
     // For normal BMM, mNumTokens == 0 and the number of CTAs is known to host.
     if (options.mIsStaticBatch) {
       for (int32_t bi = 0; bi < options.mNumBatches; ++bi) {
-        numCtasBatch += batchM ? gemm::divUp(options.mBatchedM[bi], options.mTileM)
-                               : gemm::divUp(options.mBatchedN[bi], options.mTileN);
+        numCtasBatch +=
+            batchM ? gemm::divUp(options.mBatchedM[bi], options.mTileM * options.mClusterDimX) *
+                         options.mClusterDimX
+                   : gemm::divUp(options.mBatchedN[bi], options.mTileN);
       }
     }
     // For MoE, mNumTokens != 0 and the number of CTAs is known only at runtime.
```
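The rounding keeps the CTA count a whole number of clusters; a worked sketch (numbers illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors gemm::divUp from the diff: ceiling division.
constexpr int32_t divUp(int32_t x, int32_t d) { return (x + d - 1) / d; }

int main() {
  int32_t const m = 60, tileM = 64, clusterDimX = 2;
  // The old formula launched 1 CTA; the new one launches a full cluster of 2,
  // so the batch dimension stays a multiple of clusterDimX.
  std::printf("old=%d new=%d\n", divUp(m, tileM),
              divUp(m, tileM * clusterDimX) * clusterDimX);  // old=1 new=2
  return 0;
}
```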
```diff
@@ -711,11 +737,24 @@ class BatchedGemmInterface {

   //////////////////////////////////////////////////////////////////////////////////////////////////

+  // Returns the launch grid of the current kernel.
+  std::vector<int32_t> getLaunchGrid(
+      BatchedGemmOptions const& options,
+      std::optional<int32_t> maxNumCtasInBatchDim = std::nullopt) const {
+    auto [numCtaBatch, numCtaTile, numCtaInner] = getGridDim(options, maxNumCtasInBatchDim);
+    bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
+    std::vector<int32_t> grid = batchM ? std::vector<int32_t>{numCtaBatch, numCtaTile, numCtaInner}
+                                       : std::vector<int32_t>{numCtaTile, numCtaBatch, numCtaInner};
+    return grid;
+  }
+
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+
   // Returns the number of CTAs of the current kernel.
   int32_t getNumCtas(BatchedGemmOptions const& options,
                      std::optional<int32_t> maxNumCtasInBatchDim = std::nullopt) const {
-    auto [numCtasBatch, numCtasTile, numCtasInner] = getGridDim(options, maxNumCtasInBatchDim);
-    return numCtasBatch * numCtasTile * numCtasInner;
+    auto grid = getLaunchGrid(options, maxNumCtasInBatchDim);
+    return grid[0] * grid[1] * grid[2];
   }

   //////////////////////////////////////////////////////////////////////////////////////////////////
```
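The shape of this refactor, reduced to a standalone sketch (not the PR's actual class; getGridDim is replaced by plain arguments): the grid construction lives in one place and the CTA count is derived from it, so the batchM/batchN axis swap cannot drift between the two call sites.

```cpp
#include <cstdint>
#include <vector>

// Illustrative: one source of truth for the launch grid...
std::vector<int32_t> launchGrid(bool batchM, int32_t ctaBatch, int32_t ctaTile, int32_t ctaInner) {
  return batchM ? std::vector<int32_t>{ctaBatch, ctaTile, ctaInner}
                : std::vector<int32_t>{ctaTile, ctaBatch, ctaInner};
}

// ...and the CTA count derived from it, as getNumCtas now does.
int32_t numCtas(bool batchM, int32_t ctaBatch, int32_t ctaTile, int32_t ctaInner) {
  auto const grid = launchGrid(batchM, ctaBatch, ctaTile, ctaInner);
  return grid[0] * grid[1] * grid[2];
}
```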
```diff
@@ -779,19 +818,20 @@ class BatchedGemmInterface {
     auto const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
     if (!options.mEnablesEarlyExit || options.mNumTokens == 0) {
       for (int32_t bi = 0; bi < options.mNumBatches; ++bi) {
-        totalNumPaddedTokens += batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM)
-                                       : gemm::divUpMul(options.mBatchedN[bi], options.mTileN);
+        totalNumPaddedTokens +=
+            batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM * options.mClusterDimX)
+                   : gemm::divUpMul(options.mBatchedN[bi], options.mTileN);
       }
     } else {
       // Get tile in token dim.
-      auto tileTokensDim = batchM ? options.mTileM : options.mTileN;
+      auto tileTokensDim = batchM ? options.mTileM * options.mClusterDimX : options.mTileN;
       totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim;
     }

     // Get options from config.
     auto& options = config.mOptions;

-    int const tokenTile = batchM ? options.mTileM : options.mTileN;
+    int const tokenTile = batchM ? options.mTileM * options.mClusterDimX : options.mTileN;

     auto const numTokens = totalNumPaddedTokens;
     auto const intermediateDim = batchM ? options.mN : options.mM;
```
Contributor: The introduction of `K / S` in the logical shape for `MatrixLayout::MajorK` when `batchN` is a critical change for correctly representing the dimensions of sparse matrices. This ensures that the effective K dimension is properly accounted for.