Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/trtllm_batched_gemm_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
(!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct &&
options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
tileSize == mOptions.tileSize &&
options.mUseShuffledMatrixA == mOptions.useShuffledMatrixA &&
options.mUseShuffledMatrix == mOptions.useShuffledMatrixA &&
options.mLayoutA == mOptions.weightLayout) {
if (options.mFusedAct) {
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType)) {
Expand Down
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ArtifactPath:

TRTLLM_GEN_FMHA: str = "75d477a640f268ea9ad117cc596eb39245713b9e/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"fea3b0ecfe11d7b34556042aeb5d0465ad101500/batched_gemm-332ffef-9936841"
"e1e11bbfe0743743620ef997a6d5e8e2dbdf01cf/batched_gemm-2a674db-79e4d37"
)
TRTLLM_GEN_GEMM: str = (
"1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
Expand All @@ -110,7 +110,7 @@ class CheckSumHash:
"e014d7a54c396733ef012b223603c1be2861019f88faa5dcc882ed1ecfe5c2d9"
)
TRTLLM_GEN_BMM: str = (
"1c3c7ae0755a0acb7ad35da7dbdb90ab71c253dc289051faa2b4e3180dfc4b23"
"03b1a419b594b7a4613ea8437c172dc2627d56bd360be25aa604859dc12a05fb"
)
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
TRTLLM_GEN_GEMM: str = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ struct BatchedGemmData {
struct InputBuffers {
// The matrix A. The data type is controlled by options.mDtypeA.
//
// Sparsity is only supported with batchN.
// Let S be the sparsity ratio (1 for dense, 2 for sparse).
//
// If (routeAct == true && batchM), the shape is [M, K]
// Elseif (batchStrideInTokens > 0)
// If batchM:
Expand All @@ -104,23 +107,23 @@ struct BatchedGemmData {
//
// If batchN:
// If layoutA is MatrixLayout::MajorK
// Logical shape is [B, divUpMul(M, tileM), K].
// Logical strides are [divUpMul(M, tileM) * K, K, 1].
// If layoutA is MatrixLayout::MajorMn
// Logical shape is [B, divUpMul(M, tileM), K / S].
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The introduction of K / S in the logical shape for MatrixLayout::MajorK in the batchN case is a critical change for correctly representing the dimensions of sparse matrices. It ensures that the effective K dimension is properly accounted for.

         Logical shape is [B, divUpMul(M, tileM), K / S].

// Logical strides are [divUpMul(M, tileM) * K / S, K / S, 1].
// If layoutA is MatrixLayout::MajorMn (sparsity not supported)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Adding (sparsity not supported) to the MatrixLayout::MajorMn description is important for clarity and correctness, explicitly stating the limitations of sparsity with this layout.

      If layoutA is MatrixLayout::MajorMn (sparsity not supported)

// Logical shape is [B, K, divUpMul(M, tileM)].
// Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM), 1].
// If layoutA is MatrixLayout::BlockMajorK
// Logical shape is [B, K / blockK, divUpMul(M, tileM), blockK].
// Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM) * blockK, blockK, 1].
// where blockK is 128B.
// Logical shape is [B, K / S / blockK, divUpMul(M, tileM), blockK].
// Logical strides are [K / S * divUpMul(M, tileM), divUpMul(M, tileM) * blockK, blockK,
// 1]. where blockK is 128B.
// Else // batchStrideInTokens == 0
// If batchM:
// If batchM: (sparsity not supported)
// Logical shape is [M, K].
// Logical strides are [K, 1].
//
// If batchN:
// Logical shape is [B, divUpMul(M, tileM), K].
// Logical strides are [divUpMul(M, tileM) * K, K, 1].
// Logical shape is [B, divUpMul(M, tileM), K / S].
// Logical strides are [divUpMul(M, tileM) * K / S, K / S, 1].
void const* mPtrA{nullptr};

// The block scaling factors to dequantize A.
Expand All @@ -136,9 +139,9 @@ struct BatchedGemmData {
// The layout of scaling factors for A is always R128c4
// M must be a multiple of 128.
// K must be a multiple of 64.
// The "logical" shape is: [paddedM, K / 16].
// The R128c4 layout is: [paddedM / 128, K / 16 / 4, 512].
// The shape we use for TMA is: [paddedM / 128, K / 16 / 4, 2, 256].
// The "logical" shape is: [paddedM, K / P], where P is the scaling block size.
// The R128c4 layout is: [paddedM / 128, K / P / 4, 512].
// The shape we use for TMA is: [paddedM / 128, K / P / 4, 2, 256].
// Where paddedM is M if (routeAct == true && batchM), or
// sum(divUpMul(M[bi], tileM) for bi in B) if batchM,
// otherwise divUpMul(M, tileM) * B.
Expand Down Expand Up @@ -209,15 +212,15 @@ struct BatchedGemmData {
// If the layout is R128c4,
// paddedN must be a multiple of 128.
// K must be a multiple of 64.
// The R128c4 layout is: [paddedN / 128, K / 16 / 4, 512]
// The shape we use for TMA is: [paddedN / 128, K / 16 / 4, 2, 256]
// The R128c4 layout is: [paddedN / 128, K / P / 4, 512], where P is the scaling block
// size. The shape we use for TMA is: [paddedN / 128, K / P / 4, 2, 256]
//
// If the layout is R8c4,
// paddedN must be a multiple of 8.
// K must be a multiple of 64.
// The R8c4 layout is: [paddedN / 8, K / 16 / 4, 32]
// The shape we use for TMA is: [paddedN / 8, K / 16 / 4 / repeats, repeats * 32]
// where repeats = min(tileK / 16 / 4, 8)
// The R8c4 layout is: [paddedN / 8, K / P / 4, 32], where P is the scaling block size.
// The shape we use for TMA is: [paddedN / 8, K / P / 4 / repeats, repeats * 32]
// where repeats = min(tileK / P / 4, 8)
//
// where paddedN is N if (routeAct == true && batchN),
// or sum(divUpMul(N[bi], tileN) for bi in B) if batchN,
Expand All @@ -243,6 +246,25 @@ struct BatchedGemmData {
// Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B)]
void const* mPtrPerTokenSfB{nullptr};

// The sparsity information of A, if structured sparsity is used.
// Only supported for batchN (A is weights).
//
// When sparsityA is Any_2_4:
// 2 elements are non-zero in any chunk of 4 elements.
// A 4-bit index indicates the position of the non-zero elements.
// The shape in Uint8 is: [B, divUpMul(M, tileM), K / 8]
// (two 4-bit indices packed into one UInt8)
//
// When sparsityA is Pairwise_4_8:
// 4 elements are non-zero in any chunk of 8 elements.
// The zero and non-zero elements are grouped in pairs.
// A 4-bit index indicates the position of the non-zero pairs.
// The shape in Uint8 is: [B, divUpMul(M, tileM), K / 16]
// (two 4-bit indices packed into one UInt8)
//
// If sparsityA is Dense, this should be set to nullptr.
void const* mPtrSparsityInfoA{nullptr};

// The bias applied after the GEMM and before the activation function.
// The bias is applied before applying the global scaling factor. I.e.
// C = act(A * B + bias') * scaleC
Expand Down Expand Up @@ -381,7 +403,8 @@ struct BatchedGemmData {
// Computed as
// int32_t totalNumPaddedTokens{0};
// for (int bi = 0; bi < options.mNumBatches; bi++) {
// totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM)
// totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM *
// options.mClusterDimX)
Comment on lines +406 to +407
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Modifying the totalNumPaddedTokens calculation to include options.mClusterDimX is important for correctly handling multi-CTA GEMM configurations. This ensures accurate padding and token counting in such scenarios.

    //   totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM * options.mClusterDimX)

// : divUpMul(options.mBatchedN[bi], options.mTileN);
// }
// The size is 1 and the dtype is int32_t.
Expand Down Expand Up @@ -430,26 +453,26 @@ struct BatchedGemmData {
// The output block scaling factors for C.
//
// If MxFp{4,8} and NvFp4 formats are used,
// The "logical" shape is:
// if batchM: [paddedM, N / 16]
// if batchN: [paddedN, M / 16]
// The "logical" shape is (P is the scaling block size):
// if batchM: [paddedM, N / P]
// if batchN: [paddedN, M / P]
// where paddedM is sum(divUpMul(M[bi], tileM) for bi in B),
// where paddedN is sum(divUpMul(N[bi], tileN) for bi in B).
//
// If the layout is R128c4,
// paddedOuter must be a multiple of 128.
// inner must be a multiple of 64.
// The R128c4 layout is: [paddedOuter / 128, inner / 16 / 4, 512]
// The shape we use for TMA is: [paddedOuter / 128, inner / 16 / 4, 2, 256]
// The R128c4 layout is: [paddedOuter / 128, inner / P / 4, 512]
// The shape we use for TMA is: [paddedOuter / 128, inner / P / 4, 2, 256]
// where inner = N if batchM, otherwise M.
// where paddedOuter = paddedM if batchM, otherwise paddedN.
//
// If the layout is R8c4,
// paddedOuter must be a multiple of 8.
// inner must be a multiple of 64.
// The R8c4 layout is: [paddedOuter / 8, inner / 16 / 4, 32]
// The shape we use for TMA is: [paddedOuter / 8, inner / 16 / 4 / repeats, repeats * 32]
// where repeats = min(tileInner / 16 / 4, 8),
// The R8c4 layout is: [paddedOuter / 8, inner / P / 4, 32]
// The shape we use for TMA is: [paddedOuter / 8, inner / P / 4 / repeats, repeats * 32]
// where repeats = min(tileInner / P / 4, 8),
// where tileInner = tileN if batchM, otherwise tileM,
// where paddedOuter = paddedM if batchM, otherwise paddedN.
// where inner = N if batchM, otherwise M.
Expand Down Expand Up @@ -528,11 +551,13 @@ class BatchedGemmInterface {

auto [numCtaBatch, numCtaTile, numCtaInner] =
getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim);

auto kernelParams = KernelParamsSetup::setKernelParams(
options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB,
batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA,
batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA,
batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias,
batchedGemmData.mInputBuffers.mPtrPerTokenSfB,
batchedGemmData.mInputBuffers.mPtrSparsityInfoA, batchedGemmData.mInputBuffers.mPtrBias,
batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC,
batchedGemmData.mInputBuffers.mPtrScaleAct, batchedGemmData.mInputBuffers.mPtrScaleGate,
batchedGemmData.mInputBuffers.mPtrClampLimit,
Expand All @@ -544,8 +569,7 @@ class BatchedGemmInterface {
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch);

// The size of the grid.
std::vector<int32_t> grid = batchM ? std::vector<int32_t>{numCtaBatch, numCtaTile, numCtaInner}
: std::vector<int32_t>{numCtaTile, numCtaBatch, numCtaInner};
auto grid = getLaunchGrid(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim);

BatchedGemmConfig batchedGemmConfig = config;
#ifndef TLLM_GEN_EXPORT_INTERFACE
Expand Down Expand Up @@ -676,8 +700,10 @@ class BatchedGemmInterface {
// For normal BMM, mNumTokens == 0 and the number of CTAs is known to host.
if (options.mIsStaticBatch) {
for (int32_t bi = 0; bi < options.mNumBatches; ++bi) {
numCtasBatch += batchM ? gemm::divUp(options.mBatchedM[bi], options.mTileM)
: gemm::divUp(options.mBatchedN[bi], options.mTileN);
numCtasBatch +=
batchM ? gemm::divUp(options.mBatchedM[bi], options.mTileM * options.mClusterDimX) *
options.mClusterDimX
: gemm::divUp(options.mBatchedN[bi], options.mTileN);
}
}
// For MoE, mNumTokens != 0 and the number of CTAs is known only at runtime.
Expand Down Expand Up @@ -711,11 +737,24 @@ class BatchedGemmInterface {

//////////////////////////////////////////////////////////////////////////////////////////////////

// Returns the launch grid {x, y, z} of the current kernel.
// The batch dimension is mapped to grid.x in batchM mode and to grid.y in batchN mode;
// numCtaInner always occupies grid.z.
std::vector<int32_t> getLaunchGrid(
    BatchedGemmOptions const& options,
    std::optional<int32_t> maxNumCtasInBatchDim = std::nullopt) const {
  auto [numCtaBatch, numCtaTile, numCtaInner] = getGridDim(options, maxNumCtasInBatchDim);
  bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
  // Swap the first two grid dimensions depending on which dimension is batched.
  std::vector<int32_t> grid = batchM ? std::vector<int32_t>{numCtaBatch, numCtaTile, numCtaInner}
                                     : std::vector<int32_t>{numCtaTile, numCtaBatch, numCtaInner};
  return grid;
}

//////////////////////////////////////////////////////////////////////////////////////////////////

// Returns the number of CTAs of the current kernel.
int32_t getNumCtas(BatchedGemmOptions const& options,
std::optional<int32_t> maxNumCtasInBatchDim = std::nullopt) const {
auto [numCtasBatch, numCtasTile, numCtasInner] = getGridDim(options, maxNumCtasInBatchDim);
return numCtasBatch * numCtasTile * numCtasInner;
auto grid = getLaunchGrid(options, maxNumCtasInBatchDim);
return grid[0] * grid[1] * grid[2];
}

//////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -779,19 +818,20 @@ class BatchedGemmInterface {
auto const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
if (!options.mEnablesEarlyExit || options.mNumTokens == 0) {
for (int32_t bi = 0; bi < options.mNumBatches; ++bi) {
totalNumPaddedTokens += batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM)
: gemm::divUpMul(options.mBatchedN[bi], options.mTileN);
totalNumPaddedTokens +=
batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM * options.mClusterDimX)
: gemm::divUpMul(options.mBatchedN[bi], options.mTileN);
}
} else {
// Get tile in token dim.
auto tileTokensDim = batchM ? options.mTileM : options.mTileN;
auto tileTokensDim = batchM ? options.mTileM * options.mClusterDimX : options.mTileN;
totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim;
}

// Get options from config.
auto& options = config.mOptions;

int const tokenTile = batchM ? options.mTileM : options.mTileN;
int const tokenTile = batchM ? options.mTileM * options.mClusterDimX : options.mTileN;

auto const numTokens = totalNumPaddedTokens;
auto const intermediateDim = batchM ? options.mN : options.mM;
Expand Down
Loading
Loading