diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu index a19c89638d..ead183cfd2 100644 --- a/csrc/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/trtllm_fused_moe_dev_kernel.cu @@ -299,10 +299,14 @@ __global__ void activationDeepSeekKernel(KernelParams params) { if (permutedIdx == -1) { continue; } - s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal; + // Make sure the scale is strictly positive to avoid division by zero in case the + // maximum is zero. + float scaleOut = + fmaxf(aMaxArr[tokenInCtaIdx] / E4m3MaxVal, std::numeric_limits::min()); + s_scaleOutArr[tokenInCtaIdx] = scaleOut; int const scaleOut_idx = permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128); - params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal; + params.outDqSfsPtr[scaleOut_idx] = scaleOut; } } __syncthreads(); diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 292a35e3b2..71addb0a15 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -89,7 +89,7 @@ class ArtifactPath: TRTLLM_GEN_FMHA: str = "75d477a640f268ea9ad117cc596eb39245713b9e/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841" + "fea3b0ecfe11d7b34556042aeb5d0465ad101500/batched_gemm-332ffef-9936841" ) TRTLLM_GEN_GEMM: str = ( "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" @@ -110,7 +110,7 @@ class CheckSumHash: "e014d7a54c396733ef012b223603c1be2861019f88faa5dcc882ed1ecfe5c2d9" ) TRTLLM_GEN_BMM: str = ( - "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc" + "1c3c7ae0755a0acb7ad35da7dbdb90ab71c253dc289051faa2b4e3180dfc4b23" ) DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" TRTLLM_GEN_GEMM: str = ( diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h index c4929c47b4..d99e2511d6 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h @@ -97,7 +97,7 @@ struct BatchedGemmData { // The matrix A. The data type is controlled by options.mDtypeA. // // If (routeAct == true && batchM), the shape is [M, K] - // Else + // Elseif (batchStrideInTokens > 0) // If batchM: // Logical shape is [sum(divUpMul(M[bi], tileM) for bi in B), K]. // Logical strides are [K, 1]. @@ -113,6 +113,14 @@ struct BatchedGemmData { // Logical shape is [B, K / blockK, divUpMul(M, tileM), blockK]. // Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM) * blockK, blockK, 1]. // where blockK is 128B. + // Else // batchStrideInTokens == 0 + // If batchM: + // Logical shape is [M, K]. + // Logical strides are [K, 1]. + // + // If batchN: + // Logical shape is [B, divUpMul(M, tileM), K]. + // Logical strides are [divUpMul(M, tileM) * K, K, 1]. void const* mPtrA{nullptr}; // The block scaling factors to dequantize A. @@ -160,7 +168,7 @@ struct BatchedGemmData { // // If (routeAct == true && batchN), the shape is [N, K] // - // Else + // Else if (batchStrideInTokens > 0) // If batchN: // Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B), K]. // Logical strides are [K, 1]. @@ -176,6 +184,15 @@ struct BatchedGemmData { // Logical shape is [B, K / blockK, divUpMul(N, tileN), blockK]. // Logical strides are [K * divUpMul(N, tileN), divUpMul(N, tileN) * blockK, blockK, 1]. // where blockK is 128B. + // + // Else // batchStrideInTokens == 0 + // If batchN: + // Logical shape is [N, K]. + // Logical strides are [K, 1]. + // + // If batchM: + // Logical shape is [B, divUpMul(N, tileN), K]. + // Logical strides are [divUpMul(N, tileN) * K, K, 1]. void const* mPtrB{nullptr}; // The scaling factors to dequantize B. @@ -256,6 +273,13 @@ struct BatchedGemmData { // Shape is [B]. float const* mPtrScaleC{nullptr}; + // The pre-activation scaling factor (typically dequantA * dequantB) for non-gated non-linear + // activation. + // Only used when non-linear activation is applied (e.g., GELU, Relu2). + // When used, scaleC should be quantScaleC only, and this scale is applied before the + // activation. Shape is [B]. + float const* mPtrScaleAct{nullptr}; + // The output gate scale for Fp8 (not DeepSeek FP8) and NvFp4 quantization. // TensorRT-LLM API requires a scaling factor on the device. // scaleGate = dequantA * dequantB, @@ -478,7 +502,127 @@ class BatchedGemmInterface { int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /*multiProcessorCount*/, bool usePdl = true, - std::optional> moduleCache = std::nullopt); + std::optional> moduleCache = std::nullopt) { + // Get options from config and data. + auto options = getOptionsFromConfigAndData(config, batchedGemmData); + + bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; + bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 && + options.mDtypeB == tg::Dtype::E4m3; + + auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData); + float* dPtrRowMax{nullptr}; + uint32_t* dPtrRowMaxBars{nullptr}; + + // Set the completion barriers to 0 if needed. + if (useDeepSeekFp8 && options.mFusedAct) { + dPtrRowMax = reinterpret_cast(alignPtr(reinterpret_cast(workspace), 1024)); + dPtrRowMaxBars = reinterpret_cast( + alignPtr(reinterpret_cast(dPtrRowMax) + workspaceSizes[0], 1024)); + auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1], + reinterpret_cast(cudaStream)); + if (err != cudaSuccess) { + return 1; + } + } + + auto [numCtaBatch, numCtaTile, numCtaInner] = + getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); + auto kernelParams = KernelParamsSetup::setKernelParams( + options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB, + batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA, + batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA, + batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias, + batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC, + batchedGemmData.mInputBuffers.mPtrScaleAct, batchedGemmData.mInputBuffers.mPtrScaleGate, + batchedGemmData.mInputBuffers.mPtrClampLimit, + batchedGemmData.mInputBuffers.mPtrGatedActAlpha, + batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap, + dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, + batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, + batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, + batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch); + + // The size of the grid. + std::vector grid = batchM ? std::vector{numCtaBatch, numCtaTile, numCtaInner} + : std::vector{numCtaTile, numCtaBatch, numCtaInner}; + + BatchedGemmConfig batchedGemmConfig = config; +#ifndef TLLM_GEN_EXPORT_INTERFACE + // Generate and compile the kernel if data is not provided. + if (config.mData == nullptr) { + batchedGemmConfig = generateAndCompileKernel(batchedGemmConfig); + } + TLLM_CHECK_ERROR(batchedGemmConfig.mCudaRunner != nullptr, "CudaRunner is not set"); + batchedGemmConfig.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid, + /* cluster */ {}, + /* instanceId */ batchedGemmConfig.mInstanceIdx); + return 0; +#endif + + CUmodule cuModule; + CUfunction cuFunction; + + if (moduleCache.has_value()) { + ModuleCache& moduleCacheRef = moduleCache.value().get(); + + // Modules are associated with a specific context, so the context is included in the key + CUcontext ctx; + unsigned long long ctxId; + cuCtxGetCurrent(&ctx); + cuCtxGetId(ctx, &ctxId); + + // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a + // string in decimal representation. + std::string const ctxName = + std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); + std::string const funcName = std::string(batchedGemmConfig.mFunctionName); + auto const moduleKey = ctxName + funcName; + auto module = moduleCacheRef.find(moduleKey); + + // Use cache if module is found, otherwise load and insert into cache + if (module != moduleCacheRef.end()) { + cuFunction = std::get<1>(module->second); + } else { + gemm::loadCubinData(&cuModule, batchedGemmConfig); + cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName); + moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); + } + } else { + gemm::loadCubinData(&cuModule, batchedGemmConfig); + cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName); + } + + // Prepare the grid/block. + dim3 block3{static_cast(batchedGemmConfig.mNumThreadsPerCTA), + static_cast(1), static_cast(1)}; + dim3 grid3{(grid.size() > 0 ? static_cast(grid[0]) : 1u), + (grid.size() > 1 ? static_cast(grid[1]) : 1u), + (grid.size() > 2 ? static_cast(grid[2]) : 1u)}; + // Prepare the cluster size. + dim3 cluster3{static_cast(options.mClusterDimX), + static_cast(options.mClusterDimY), + static_cast(options.mClusterDimZ)}; + + // Whether PDL can safely be enabled + const bool pdlSafe = batchedGemmConfig.mOptions.mGridWaitForPrimaryRouting || + batchedGemmConfig.mOptions.mGridWaitForPrimaryEarlyExit || + batchedGemmConfig.mOptions.mGridWaitForPrimaryA || + batchedGemmConfig.mOptions.mGridWaitForPrimaryB; + + // Run the kernel. + auto result = trtllm::gen::launchKernel((void*)&kernelParams, cudaStream, + batchedGemmConfig.mSharedMemSize, cuFunction, block3, + grid3, cluster3, usePdl && pdlSafe); + if (result != CUDA_SUCCESS) { + return result; + } + // If a module cache has not been given, unload the module to avoid leaking + if (!moduleCache.has_value()) { + cuModuleUnload(cuModule); + } + return 0; + } ////////////////////////////////////////////////////////////////////////////////////////////////// @@ -683,130 +827,6 @@ class BatchedGemmInterface { int32_t mNumRotations; }; -int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace, - BatchedGemmData const& batchedGemmData, void* cudaStream, - int32_t /*multiProcessorCount*/, bool usePdl, - std::optional> moduleCache) { - // Get options from config and data. - auto options = getOptionsFromConfigAndData(config, batchedGemmData); - - bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; - bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 && - options.mDtypeB == tg::Dtype::E4m3; - - auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData); - float* dPtrRowMax{nullptr}; - uint32_t* dPtrRowMaxBars{nullptr}; - - // Set the completion barriers to 0 if needed. - if (useDeepSeekFp8 && options.mFusedAct) { - dPtrRowMax = reinterpret_cast(alignPtr(reinterpret_cast(workspace), 1024)); - dPtrRowMaxBars = reinterpret_cast( - alignPtr(reinterpret_cast(dPtrRowMax) + workspaceSizes[0], 1024)); - auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1], - reinterpret_cast(cudaStream)); - if (err != cudaSuccess) { - return 1; - } - } - - auto [numCtaBatch, numCtaTile, numCtaInner] = - getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); - auto kernelParams = KernelParamsSetup::setKernelParams( - options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB, - batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA, - batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA, - batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias, - batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC, - batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit, - batchedGemmData.mInputBuffers.mPtrGatedActAlpha, - batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap, - dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, - batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch); - - // The size of the grid. - std::vector grid = batchM ? std::vector{numCtaBatch, numCtaTile, numCtaInner} - : std::vector{numCtaTile, numCtaBatch, numCtaInner}; - - BatchedGemmConfig batchedGemmConfig = config; -#ifndef TLLM_GEN_EXPORT_INTERFACE - // Generate and compile the kernel if data is not provided. - if (config.mData == nullptr) { - batchedGemmConfig = generateAndCompileKernel(batchedGemmConfig); - } - TLLM_CHECK_ERROR(batchedGemmConfig.mCudaRunner != nullptr, "CudaRunner is not set"); - batchedGemmConfig.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid, - /* cluster */ {}, - /* instanceId */ batchedGemmConfig.mInstanceIdx); - return 0; -#endif - - CUmodule cuModule; - CUfunction cuFunction; - - if (moduleCache.has_value()) { - ModuleCache& moduleCacheRef = moduleCache.value().get(); - - // Modules are associated with a specific context, so the context is included in the key - CUcontext ctx; - unsigned long long ctxId; - cuCtxGetCurrent(&ctx); - cuCtxGetId(ctx, &ctxId); - - // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a - // string in decimal representation. - std::string const ctxName = - std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); - std::string const funcName = std::string(batchedGemmConfig.mFunctionName); - auto const moduleKey = ctxName + funcName; - auto module = moduleCacheRef.find(moduleKey); - - // Use cache if module is found, otherwise load and insert into cache - if (module != moduleCacheRef.end()) { - cuFunction = std::get<1>(module->second); - } else { - gemm::loadCubinData(&cuModule, batchedGemmConfig); - cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName); - moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); - } - } else { - gemm::loadCubinData(&cuModule, batchedGemmConfig); - cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName); - } - - // Prepare the grid/block. - dim3 block3{static_cast(batchedGemmConfig.mNumThreadsPerCTA), static_cast(1), - static_cast(1)}; - dim3 grid3{(grid.size() > 0 ? static_cast(grid[0]) : 1u), - (grid.size() > 1 ? static_cast(grid[1]) : 1u), - (grid.size() > 2 ? static_cast(grid[2]) : 1u)}; - // Prepare the cluster size. - dim3 cluster3{static_cast(options.mClusterDimX), - static_cast(options.mClusterDimY), - static_cast(options.mClusterDimZ)}; - - // Whether PDL can safely be enabled - const bool pdlSafe = batchedGemmConfig.mOptions.mGridWaitForPrimaryRouting || - batchedGemmConfig.mOptions.mGridWaitForPrimaryEarlyExit || - batchedGemmConfig.mOptions.mGridWaitForPrimaryA || - batchedGemmConfig.mOptions.mGridWaitForPrimaryB; - - // Run the kernel. - auto result = - trtllm::gen::launchKernel((void*)&kernelParams, cudaStream, batchedGemmConfig.mSharedMemSize, - cuFunction, block3, grid3, cluster3, usePdl && pdlSafe); - if (result != CUDA_SUCCESS) { - return result; - } - // If a module cache has not been given, unload the module to avoid leaking - if (!moduleCache.has_value()) { - cuModuleUnload(cuModule); - } - return 0; -} - //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace batchedGemm diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h index f3e73a5aac..acbc01497c 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h @@ -16,6 +16,7 @@ */ #pragma once +#include #include #include @@ -107,10 +108,12 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { // GemmGatedActOptions gemmGatedAct::ActType actType, bool clampBeforeAct, // BatchedGemmOptions - std::vector batchedM, std::vector batchedN, BatchMode batchMode, bool fusedAct, - bool gridWaitForPrimaryRouting, bool isStaticBatch, int numBatches, int numRegsPerThreadLoadB, - int numRegsPerThreadLoadSfB, int numTokens, int numWarpsLoadB, int numWarpsLoadSfB, - RouteImpl routeImpl, std::optional routeSfsImpl, bool useTmaOobOpt) + std::vector batchedM, std::vector batchedN, BatchMode batchMode, + int32_t batchStrideInTokens, bool fusedAct, bool gridWaitForPrimaryRouting, + bool isStaticBatch, bool isUniformNumTokensPerBatch, int numBatches, + int numRegsPerThreadLoadB, int numRegsPerThreadLoadSfB, int numTokens, int numWarpsLoadB, + int numWarpsLoadSfB, RouteImpl routeImpl, std::optional routeSfsImpl, + bool useTmaOobOpt) : gemmGatedAct::GemmGatedActOptions( gemm::GemmOptions( allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, @@ -134,9 +137,11 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { mBatchedM(batchedM), mBatchedN(batchedN), mBatchMode(BatchMode(batchMode)), + mBatchStrideInTokens(batchStrideInTokens), mFusedAct(fusedAct), mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting), mIsStaticBatch(isStaticBatch), + mIsUniformNumTokensPerBatch(isUniformNumTokensPerBatch), mNumBatches(numBatches), mNumRegsPerThreadLoadB{numRegsPerThreadLoadB}, mNumRegsPerThreadLoadSfB{numRegsPerThreadLoadSfB}, @@ -153,6 +158,8 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { std::vector mBatchedN; // Whether batching M or N. BatchMode mBatchMode{BatchMode::BatchM}; + // Stride between batches in tokens dimension for input matrix. + int32_t mBatchStrideInTokens{-1}; // Whether to perform a fused gated activation. bool mFusedAct{false}; // Whether the loads that load from ptrRouteMap, ptrTotalNumPaddedTokens, @@ -160,6 +167,8 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { bool mGridWaitForPrimaryRouting{true}; // Whether the batch size is static (i.e. known at kernel launch time). bool mIsStaticBatch{true}; + // Whether the number of tokens in each entry of the batch is the same. + bool mIsUniformNumTokensPerBatch{false}; // Number of Gemm batches. int mNumBatches; // Number of registers per thread for load B @@ -366,6 +375,58 @@ inline bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, tg::Cu "change the input routing data layout to be padded to clusterDimX size."); } + // Check if all elements in mBatchedM or mBatchedN are the same (uniform tokens per batch) and + // set mIsUniformNumTokensPerBatch and mBatchStride. + if (options.mIsUniformNumTokensPerBatch) { + int32_t firstValue = 0; + bool isUniformNumTokensPerBatch = false; + if (batchM && !options.mBatchedM.empty()) { + firstValue = options.mBatchedM[0]; + isUniformNumTokensPerBatch = std::all_of(options.mBatchedM.begin(), options.mBatchedM.end(), + [firstValue](int32_t v) { return v == firstValue; }); + } else if (!batchM && !options.mBatchedN.empty()) { + firstValue = options.mBatchedN[0]; + isUniformNumTokensPerBatch = std::all_of(options.mBatchedN.begin(), options.mBatchedN.end(), + [firstValue](int32_t v) { return v == firstValue; }); + } else { + TLLM_CHECK_ERROR( + false, "mBatchedM or mBatchedN must be specified when using uniform tokens per batch."); + } + auto tileTokensDim = batchM ? options.mTileM : options.mTileN; + TLLM_CHECK_ERROR(isUniformNumTokensPerBatch, + "All elements in mBatchedM or mBatchedN must be the same when using uniform " + "tokens per batch."); + TLLM_CHECK_ERROR(options.mBatchStrideInTokens >= 0, + "Batch stride in tokens must be greater or equal to 0 when using uniform " + "tokens per batch."); + TLLM_CHECK_ERROR_FMT( + options.mBatchStrideInTokens == 0 || + options.mBatchStrideInTokens == gemm::divUpMul(firstValue, tileTokensDim), + "Batch stride in tokens must be a 0 or a multiple of %s {%d} when using " + "uniform tokens per batch.", + batchM ? "TileM" : "TileN", tileTokensDim); + TLLM_CHECK_ERROR( + !options.mUseDeepSeekFp8, + "Uniform number of tokens per batch is not supported when using DeepSeek Fp8."); + TLLM_CHECK_ERROR( + !options.mUsePerTokenSfA && !options.mUsePerTokenSfB, + "Uniform number of tokens per batch is not supported when using per-token SF."); + TLLM_CHECK_ERROR(options.mBiasType == gemm::BiasType::None, + "Uniform number of tokens per batch is not supported when using bias."); + TLLM_CHECK_ERROR(options.mRouteImpl == RouteImpl::NoRoute, + "Uniform number of tokens per batch is not supported when using routing."); + TLLM_CHECK_ERROR( + !options.mFusedAct, + "Uniform number of tokens per batch is not supported when using fused gated activation."); + TLLM_CHECK_ERROR(!tg::dtypeIsBlockFmt(options.mDtypeA) && + !tg::dtypeIsBlockFmt(options.mDtypeB) && + !tg::dtypeIsBlockFmt(options.mDtypeC), + "Uniform number of tokens per batch is not supported when using block " + "format for dtypeA, dtypeB, or dtypeC."); + } else if (options.mBatchStrideInTokens >= 0) { + TLLM_LOG_WARNING("Batch stride in tokens is set to ", options.mBatchStrideInTokens, + " but it is not used when not using uniform tokens per batch."); + } return isValid; } @@ -404,9 +465,13 @@ inline std::string dumpOptions(BatchedGemmOptions const& options, bool dumpRunti } ss << "mBatchMode=batchedGemm::BatchedGemmOptions::BatchMode(" << static_cast(options.mBatchMode) << ")," << std::endl; + if (dumpRuntimeParams) { + ss << "mBatchStrideInTokens=" << options.mBatchStrideInTokens << "," << std::endl; + } ss << "mFusedAct=" << options.mFusedAct << "," << std::endl; ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl; ss << "mIsStaticBatch=" << options.mIsStaticBatch << "," << std::endl; + ss << "mIsUniformNumTokensPerBatch=" << options.mIsUniformNumTokensPerBatch << "," << std::endl; if (dumpRuntimeParams) { ss << "mNumBatches=" << options.mNumBatches << "," << std::endl; } diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h index 2a1c371ad8..f0b63e674e 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h @@ -90,6 +90,12 @@ enum class BiasType : uint32_t { // Type of the element-wise activation to apply after the Gemm enum class EltwiseActType { None = 0, + // Gelu is defined as the following operation: + // act = x0 * phi(x0) + // where x0 is the output of the Gemm + // phi is the CDF of standard normal distribution approximated by + // phi(x) = 0.5 * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x))) + Gelu, // Relu2 (also known as squared Relu) is defined as the following operation: // act = relu(x0) ^ 2 // where x0 is the output of the Gemm. diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h index 9fb4a010a4..b3d955dc7c 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h @@ -185,8 +185,8 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& inline std::string dumpOptions(GemmGatedActOptions const& options, bool dumpRuntimeParams = true) { std::stringstream ss; ss << gemm::dumpOptions(options, dumpRuntimeParams) << ", "; - ss << "mActType=" - << "gemmGatedAct::ActType(" << static_cast(options.mActType) << ")," << std::endl; + ss << "mActType=" << "gemmGatedAct::ActType(" << static_cast(options.mActType) << ")," + << std::endl; ss << "mClampBeforeAct=" << options.mClampBeforeAct << "" << std::endl; return ss.str(); } diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h index 54daac4a8d..37da75c10d 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h @@ -446,40 +446,30 @@ inline std::string toString(trtllm::gen::MmaKind e) { inline std::string dumpOptions(GemmOptions const& options, bool dumpRuntimeParams = true) { std::stringstream ss; - ss << "mAllReduceAlgo=" - << "gemm::AllReduceAlgo(" << static_cast(options.mAllReduceAlgo) << ")" - << "," << std::endl; - ss << "mBiasType=" - << "gemm::BiasType(" << static_cast(options.mBiasType) << ")" - << "," << std::endl; + ss << "mAllReduceAlgo=" << "gemm::AllReduceAlgo(" << static_cast(options.mAllReduceAlgo) + << ")" << "," << std::endl; + ss << "mBiasType=" << "gemm::BiasType(" << static_cast(options.mBiasType) << ")" << "," + << std::endl; ss << "mBlockK=" << options.mBlockK << "," << std::endl; ss << "mClusterDimX=" << options.mClusterDimX << "," << std::endl; ss << "mClusterDimY=" << options.mClusterDimY << "," << std::endl; ss << "mClusterDimZ=" << options.mClusterDimZ << "," << std::endl; - ss << "mCtaSwizzleType=" - << "gemm::CtaSwizzleType(" << static_cast(options.mCtaSwizzleType) << ")" - << "," << std::endl; - ss << "mDtypeAcc=" - << "trtllm::gen::Dtype(" << static_cast(options.mDtypeAcc) << ")" - << "," << std::endl; - ss << "mDtypeA=" - << "trtllm::gen::Dtype(" << static_cast(options.mDtypeA) << ")" + ss << "mCtaSwizzleType=" << "gemm::CtaSwizzleType(" + << static_cast(options.mCtaSwizzleType) << ")" << "," << std::endl; + ss << "mDtypeAcc=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeAcc) << ")" << "," << std::endl; - ss << "mDtypeB=" - << "trtllm::gen::Dtype(" << static_cast(options.mDtypeB) << ")" - << "," << std::endl; - ss << "mDtypeC=" - << "trtllm::gen::Dtype(" << static_cast(options.mDtypeC) << ")" - << "," << std::endl; - ss << "mDtypeMmaA=" - << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaA) << ")" - << "," << std::endl; - ss << "mDtypeMmaB=" - << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaB) << ")" + ss << "mDtypeA=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeA) << ")" << "," + << std::endl; + ss << "mDtypeB=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeB) << ")" << "," + << std::endl; + ss << "mDtypeC=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeC) << ")" << "," + << std::endl; + ss << "mDtypeMmaA=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaA) << ")" << "," << std::endl; - ss << "mEltwiseActType=" - << "gemm::EltwiseActType(" << static_cast(options.mEltwiseActType) << ")" + ss << "mDtypeMmaB=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeMmaB) << ")" << "," << std::endl; + ss << "mEltwiseActType=" << "gemm::EltwiseActType(" + << static_cast(options.mEltwiseActType) << ")" << "," << std::endl; ss << "mEnablesEarlyExit=" << options.mEnablesEarlyExit << "," << std::endl; ss << "mEnablesDelayedEarlyExit=" << options.mEnablesDelayedEarlyExit << "," << std::endl; ss << "mEnablesGlobalPtxKnobs=" << options.mEnablesGlobalPtxKnobs << "," << std::endl; @@ -498,18 +488,16 @@ inline std::string dumpOptions(GemmOptions const& options, bool dumpRuntimeParam if (dumpRuntimeParams) { ss << "mK=" << options.mK << "," << std::endl; } - ss << "mKernelTraits={}" - << "," << std::endl; - ss << "mLayoutA=gemm::MatrixLayout(" << static_cast(options.mLayoutA) << ")" - << "," << std::endl; - ss << "mLayoutB=gemm::MatrixLayout(" << static_cast(options.mLayoutB) << ")" - << "," << std::endl; + ss << "mKernelTraits={}" << "," << std::endl; + ss << "mLayoutA=gemm::MatrixLayout(" << static_cast(options.mLayoutA) << ")" << "," + << std::endl; + ss << "mLayoutB=gemm::MatrixLayout(" << static_cast(options.mLayoutB) << ")" << "," + << std::endl; if (dumpRuntimeParams) { ss << "mM=" << options.mM << "," << std::endl; } ss << "mMmaK=" << options.mMmaK << "," << std::endl; - ss << "mMmaKind=" - << "trtllm::gen::MmaKind(" << static_cast(options.mMmaKind) << ")" + ss << "mMmaKind=" << "trtllm::gen::MmaKind(" << static_cast(options.mMmaKind) << ")" << "," << std::endl; ss << "mMmaM=" << options.mMmaM << "," << std::endl; ss << "mMmaN=" << options.mMmaN << "," << std::endl; @@ -536,30 +524,23 @@ inline std::string dumpOptions(GemmOptions const& options, bool dumpRuntimeParam if (options.mSfBlockSizeA.has_value()) { ss << "mSfBlockSizeA=" << options.mSfBlockSizeA.value() << "," << std::endl; } else { - ss << "mSfBlockSizeA=" - << "std::nullopt" - << ", " << std::endl; + ss << "mSfBlockSizeA=" << "std::nullopt" << ", " << std::endl; } - ss << "mSfLayoutA=" - << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutA) << ")" + ss << "mSfLayoutA=" << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutA) << ")" << "," << std::endl; - ss << "mSfLayoutB=" - << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutB) << ")" + ss << "mSfLayoutB=" << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutB) << ")" << "," << std::endl; - ss << "mSfLayoutC=" - << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutC) << ")" + ss << "mSfLayoutC=" << "trtllm::gen::SfLayout(" << static_cast(options.mSfLayoutC) << ")" << "," << std::endl; ss << "mSfReshapeFactor=" << options.mSfReshapeFactor << "," << std::endl; ss << "mSliceK=" << options.mSliceK << "," << std::endl; - ss << "mSplitK=" - << "gemm::SplitK(" << static_cast(options.mSplitK) << ")" - << "," << std::endl; + ss << "mSplitK=" << "gemm::SplitK(" << static_cast(options.mSplitK) << ")" << "," + << std::endl; ss << "mTileK=" << options.mTileK << "," << std::endl; ss << "mTileM=" << options.mTileM << "," << std::endl; ss << "mTileN=" << options.mTileN << "," << std::endl; - ss << "mTileScheduler=" - << "gemm::TileScheduler(" << static_cast(options.mTileScheduler) << ")" - << "," << std::endl; + ss << "mTileScheduler=" << "gemm::TileScheduler(" << static_cast(options.mTileScheduler) + << ")" << "," << std::endl; ss << "mTransposeMmaOutput=" << options.mTransposeMmaOutput << "," << std::endl; ss << "mUseCustomMmaSchedule=" << options.mUseCustomMmaSchedule << "," << std::endl; ss << "mUseDeepSeekFp8=" << options.mUseDeepSeekFp8 << "," << std::endl; @@ -770,16 +751,19 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, tg::CudaArch cudaArc } } - if ((options.mMmaKind == tg::MmaKind::Fp8Fp6Fp4 || - options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) && - options.mMmaK != 32) { - TLLM_LOG_WARNING("Unsupported MmaK (", options.mMmaK, - ") for MmaKind=", gemm::toString(options.mMmaKind), ". Setting MmaK to 32"); - if (updateOptions) { - options.mMmaK = 32; - options.mTileK = std::max(options.mMmaK, options.mTileK); - } else { - return false; + if (options.mMmaKind == tg::MmaKind::Fp8Fp6Fp4) { + int mmaK = 32; + + if (options.mMmaK != mmaK) { + TLLM_LOG_WARNING("Unsupported MmaK (", options.mMmaK, + ") for MmaKind=", gemm::toString(options.mMmaKind), ". Setting MmaK to ", + mmaK); + if (updateOptions) { + options.mMmaK = mmaK; + options.mTileK = std::max(options.mMmaK, options.mTileK); + } else { + return false; + } } } @@ -807,7 +791,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, tg::CudaArch cudaArc options.mEpilogueLdtmDps, "dp", options.mEpilogueLdtmBits, "bit."); } - // Constraints for NvFp4 and MxFp8. + // Constraints for NvFp4, MxFp8, and MxFp4. if ((options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4 || options.mDtypeC == tg::Dtype::MxE4m3) && options.mMmaM != 128) { @@ -837,12 +821,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, tg::CudaArch cudaArc int mmaK = 32; if (options.mMmaKind == tg::MmaKind::MxFp4NvFp4) { + mmaK = 64; if (options.mMmaK == 96) { mmaK = 96; TLLM_CHECK_ERROR(options.mTileK == 768, "When mmaK == 96, only tileK == 768 is supported"); TLLM_CHECK_ERROR(options.mTileN <= 128, "When mmaK == 96, only tileN <= 128 is supported"); - } else { - mmaK = 64; } } if (options.mMmaK != mmaK) { @@ -1493,6 +1476,15 @@ inline bool getDoesScaleAb(tg::Dtype dtypeA, tg::Dtype dtypeB, bool useDeepSeekF return doesScaleAb; } +////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool getDoesScaleAct(tg::Dtype dtypeA, tg::Dtype dtypeB, bool useDeepSeekFp8, + EltwiseActType eltwiseActType) { + // Only non-linear activations require separate scaleAct. + bool const isLinearAct = eltwiseActType == EltwiseActType::None; + return !isLinearAct && getDoesScaleAb(dtypeA, dtypeB, useDeepSeekFp8); +} + //////////////////////////////////////////////////////////////////////////////////////////////////// inline bool getKernelDoesScaleC(tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h index 8094f1490e..ab004a9691 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h @@ -299,9 +299,9 @@ template static KernelParams setKernelParams( GemmOptions_ const& options, bool const batchM, void const* ptrA, void const* ptrB, void* ptrC, void const* dSfA, void const* dSfB, void const* ptrPerTokenSfA, void const* ptrPerTokenSfB, - void const* ptrBias, void* dSfC, float const* ptrScaleC, float const* ptrScaleGate, - float const* ptrClampLimit, float const* ptrGatedActAlpha, float const* ptrGatedActBeta, - int32_t const* routeMap, float* rowMax, uint32_t* rowMaxBars, + void const* ptrBias, void* dSfC, float const* ptrScaleC, float const* ptrScaleAct, + float const* ptrScaleGate, float const* ptrClampLimit, float const* ptrGatedActAlpha, + float const* ptrGatedActBeta, int32_t const* routeMap, float* rowMax, uint32_t* rowMaxBars, int32_t const* ptrNumNonExitingCtas = nullptr, int32_t const* ptrTotalNumPaddedTokens = nullptr, int32_t const* ptrCtaIdxXyToBatchIdx = nullptr, int32_t const* ptrCtaIdxXyToMnLimit = nullptr, int32_t const maxNumCtas = KernelParams::MaxNumCtas) { @@ -315,6 +315,7 @@ static KernelParams setKernelParams( params.numTokens = options.mNumTokens; params.ptrScaleC = ptrScaleC; + params.ptrScaleAct = ptrScaleAct; params.ptrScaleGate = ptrScaleGate; params.ptrClampLimit = ptrClampLimit; params.ptrGatedActAlpha = ptrGatedActAlpha; @@ -326,7 +327,7 @@ static KernelParams setKernelParams( // known at kernel launch time. Otherwise, these parameters are defined in the device buffers: // ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx and ptrCtaIdxXyToMnLimit respectively. - if (options.mIsStaticBatch) { + if (options.mIsStaticBatch && !options.mIsUniformNumTokensPerBatch) { params.totalNumPaddedTokens = 0; for (int b = 0; b < options.mNumBatches; b++) { int mM = batchM ? options.mBatchedM[b] : options.mM; @@ -362,6 +363,21 @@ static KernelParams setKernelParams( params.totalNumPaddedTokens += numCtas * tile; } + params.totalNumOutputPaddedTokens = params.totalNumPaddedTokens; + } else if (options.mIsStaticBatch && options.mIsUniformNumTokensPerBatch) { + auto numTokens = batchM ? options.mBatchedM[0] : options.mBatchedN[0]; + auto tileTokensDim = batchM ? options.mTileM : options.mTileN; + params.batchStrideInCtas = (options.mBatchStrideInTokens + tileTokensDim - 1) / tileTokensDim; + params.ctasInTokenDimPerBatch = (numTokens + tileTokensDim - 1) / tileTokensDim; + params.totalNumOutputPaddedTokens = + params.ctasInTokenDimPerBatch * tileTokensDim * options.mNumBatches; + if (params.batchStrideInCtas == 0) { + params.totalNumPaddedTokens = params.ctasInTokenDimPerBatch * tileTokensDim; + } else { + params.totalNumPaddedTokens = + params.ctasInTokenDimPerBatch * tileTokensDim * options.mNumBatches; + } + ctaOffset = maxNumCtas; } else { params.ptrTotalNumPaddedTokens = ptrTotalNumPaddedTokens; params.ptrCtaIdxXyToBatchIdx = ptrCtaIdxXyToBatchIdx; @@ -382,6 +398,14 @@ static KernelParams setKernelParams( params.ptrSfB = dSfB; params.ptrSfC = dSfC; + // Do we pad A? + bool doPadA = + options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4 && options.mDtypeA == tg::Dtype::MxE2m1; + + // Do we pad B? + bool doPadB = + options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4 && options.mDtypeB == tg::Dtype::MxE2m1; + if (!batchM) { // A is the expert if (0 != options.mM % options.mTileM) { @@ -394,8 +418,9 @@ static KernelParams setKernelParams( options, options.mM, options.mN, options.mK, options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixA, options.mValidM, options.mValidN, options.mValidK); // Build tma descriptor for A. - params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, strideA, - tileShapeA, const_cast(ptrA)); + params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, shapeA, strideA, tileShapeA, + const_cast(ptrA), doPadA, + /*doSwizzle=*/true); // The input is padded: // [act0, padding, padding, ... TileN size .., act1, padding, padding, ...] @@ -410,8 +435,9 @@ static KernelParams setKernelParams( options.mTileM, (useRouteAct ? 1 : options.mTileN), options.mTileK, MatrixType::MatrixB, options.mValidM, useRouteAct ? options.mNumTokens : inputNumTokens, options.mValidK); // Build tma descriptor for B. - params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, - strideB, tileShapeB, const_cast(ptrB)); + params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, shapeB, strideB, tileShapeB, + const_cast(ptrB), doPadB, + /*doSwizzle=*/true); } if (options.mDtypeA == tg::Dtype::E2m1 || options.mDtypeA == tg::Dtype::MxE4m3 || @@ -458,9 +484,10 @@ static KernelParams setKernelParams( options, options.mM, options.mNumTokens, numSfsInK, options.mTileM, 1 /* tileN */, options.mTileK / numEltsPerSf, MatrixType::MatrixB, options.mValidM, options.mNumTokens, numSfsInValidK); - params.tmaSfB[0] = gemm::buildNdTmaDescriptor( - dTypeSf, options.mMmaKind, shapeSfB, strideSfB, tileShapesSfB, const_cast(dSfB), - /*doSwizzle*/ true); + params.tmaSfB[0] = gemm::buildNdTmaDescriptor(dTypeSf, shapeSfB, strideSfB, tileShapesSfB, + const_cast(dSfB), + /*doPad=*/false, + /*doSwizzle=*/true); } else if (batchedGemm::doesRouteImplUseNoRoute(options.mRouteSfsImpl.value())) { // The input is padded: // [act0, padding, padding, ... TileN size .., act1, padding, padding, ...] @@ -487,8 +514,9 @@ static KernelParams setKernelParams( options, options.mM, ctaOffset * options.mTileN, options.mK, options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixC); // Build tma descriptor for C. - params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, - strideC, tileShapeC, ptrC); + params.tmaC[0] = + gemm::buildNdTmaDescriptor(options.mDtypeC, shapeC, strideC, tileShapeC, ptrC, + /*doPad=*/false); } else { params.ptrC = ptrC; @@ -506,8 +534,9 @@ static KernelParams setKernelParams( options, options.mM, options.mN, options.mK, options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixB, options.mValidM, options.mValidN, options.mValidK); // Build tma descriptor for B. - params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, strideB, - tileShapeB, const_cast(ptrB)); + params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, shapeB, strideB, tileShapeB, + const_cast(ptrB), doPadB, + /*doSwizzle=*/true); if (options.mRouteImpl == batchedGemm::RouteImpl::NoRoute) { // A is the activation @@ -519,8 +548,9 @@ static KernelParams setKernelParams( options, inputNumTokens, options.mN, options.mK, options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixA, inputNumTokens, options.mValidN, options.mValidK); // Build tma descriptor for A. - params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, - strideA, tileShapeA, const_cast(ptrA)); + params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, shapeA, strideA, tileShapeA, + const_cast(ptrA), doPadA, + /*doSwizzle=*/true); } if (options.mDtypeA == tg::Dtype::E2m1 || options.mDtypeA == tg::Dtype::MxE4m3 || @@ -567,8 +597,9 @@ static KernelParams setKernelParams( options, ctaOffset * options.mTileM, options.mN, options.mK, options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixC); // Build tma descriptor for C. - params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, - strideC, tileShapeC, ptrC); + params.tmaC[0] = + gemm::buildNdTmaDescriptor(options.mDtypeC, shapeC, strideC, tileShapeC, ptrC, + /*doPad=*/false); } else { params.ptrC = ptrC; } diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h index e11374739f..b1e5dba024 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h @@ -35,10 +35,16 @@ struct KernelParams { // makeTmaShapeStrideAbc. // // If batchM: - // Logical shape is [sum(divUpMul(M[bi], tileM) for bi in B), K]. - // Logical strides are [K, 1]. - // Tile box shape is [tileM, tileK]. - // Tile box strides are [tileK, 1]. + // If batchStrideInTokens > 0: + // Logical shape is [sum(divUpMul(M[bi], tileM) for bi in B), K]. + // Logical strides are [K, 1]. + // Tile box shape is [tileM, tileK]. + // Tile box strides are [tileK, 1]. + // Else // batchStrideInTokens == 0: + // Logical shape is [M, K]. + // Logical strides are [K, 1]. + // Tile box shape is [tileM, tileK]. + // Tile box strides are [tileK, 1]. // // If batchN: // If layoutA is MatrixLayout::MajorK @@ -84,10 +90,16 @@ struct KernelParams { // where blockK is 128B. // // If batchN: - // Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B), K]. - // Logical strides are [K, 1]. - // Tile box shape is [tileN, tileK]. - // Tile box strides are [tileK, 1]. + // If batchStrideInTokens > 0: + // Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B), K]. + // Logical strides are [K, 1]. + // Tile box shape is [tileN, tileK]. + // Tile box strides are [tileK, 1]. + // Else // batchStrideInTokens == 0: + // Logical shape is [N, K]. + // Logical strides are [K, 1]. + // Tile box shape is [tileN, tileK]. + // Tile box strides are [tileK, 1]. // // Dtype is set from options.mDtypeB. CUtensorMap tmaB[1]; @@ -195,6 +207,13 @@ struct KernelParams { // Shape is [B]. One scaling factor per tensor in batch. float const* ptrScaleC{nullptr}; + // The pre-activation scaling factor (typically dequantA * dequantB) for non-gated non-linear + // activation. + // Only used when non-linear activation is applied (e.g., GELU, Relu2). + // When used, scaleC should be quantScaleC only, and this scale is applied before the + // activation. Shape is [B]. + float const* ptrScaleAct{nullptr}; + // The output gate scale for MxFp{4,8}, Fp8, NvFp4 and DeepSeek FP8 quantization. // TensorRT-LLM API requires a scaling factor on the device. // Shape is [B]. One scaling factor per tensor in batch. @@ -444,6 +463,10 @@ struct KernelParams { // If isStaticBatch == true, totalNumPaddedTokens is used, otherwise ptrTotalNumPaddedTokens. int32_t totalNumPaddedTokens; + // Total number of padded tokens - used as the stride for the output activation + // and C scaling factors. This is only used when isUniformNumTokensPerBatch is true. + int32_t totalNumOutputPaddedTokens; + // A map from CTA index X/Y to batch index. // Check ptrCtaIdxXyToBatchIdx to see how it is computed. // If isStaticBatch == true, ctaIdxXyToBatchIdx is used, otherwise ptrCtaIdxXyToBatchIdx. @@ -455,6 +478,14 @@ struct KernelParams { // If isStaticBatch == true, ctaIdxXyToMnLimit is used, otherwise ptrCtaIdxXyToMnLimit. int32_t ctaIdxXyToMnLimit[MaxNumCtas]; + // Total number of CTAs in the token dimension per batch. + // Used only when isUniformNumTokensPerBatch is true. + int32_t ctasInTokenDimPerBatch{0}; + + // Stride for the batched dimension in the number of CTAs. + // Used only when isUniformNumTokensPerBatch is true. + int32_t batchStrideInCtas{0}; + ////////////////////////////////////////////////////////////////////////////////////////////////// // // All-reduce parameters. diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h index d7b0b6b62f..024be709e1 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h @@ -138,7 +138,7 @@ class MemAllocatorHelper { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind) { +inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind, int mmaK) { if (mmaKind == tg::MmaKind::Auto) { throw std::runtime_error("mmaKind != tg::MmaKind::Auto"); } @@ -168,7 +168,8 @@ class KernelTraits { bool usePerTokenSfB, bool useTwoCtas, BiasType biasType) : mMmaKind{mmaKind}, mFuseUtccpWithUtcmma{fuseUtccpWithUtcmma}, - mUseMaxTmemOverlap{useMaxTmemOverlap} { + mUseMaxTmemOverlap{useMaxTmemOverlap}, + mNumEpilogueWarps{numEpilogueWarps} { // // SMEM // @@ -202,7 +203,7 @@ class KernelTraits { { // Number of bytes in load A shared memory. auto const numSmemBytesLoadA = - numStages * tileM * tileK * getNumSmemBitsPerElt(dtypeA, mMmaKind) / 8 /* bits */; + numStages * tileM * tileK * getNumSmemBitsPerElt(dtypeA, mMmaKind, mmaK) / 8 /* bits */; // Number of bytes for load A alignment for TMA load. auto const numBytesAlignmentLoadA = 1024; // loadA is already at first chunk. No need to reuse it. @@ -218,7 +219,7 @@ class KernelTraits { { // Number of bytes in load B shared memory. auto const numSmemBytesLoadB = numStages * (useTwoCtas ? tileN / 2 : tileN) * tileK * - getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; + getNumSmemBitsPerElt(dtypeB, mMmaKind, mmaK) / 8 /* bits */; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; // No need to reuse the first chunk. @@ -238,9 +239,9 @@ class KernelTraits { { // Number of bytes in save shuffled B in shared memory. auto const numSmemBytesLoadB = - numSlicesForSliceK > 1 - ? numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */ - : 0; + numSlicesForSliceK > 1 ? numStages * tileN * tileK * + getNumSmemBitsPerElt(dtypeB, mMmaKind, mmaK) / 8 /* bits */ + : 0; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; // No need to reuse the first chunk. @@ -474,13 +475,19 @@ class KernelTraits { bool const useBlockScalingA = tg::dtypeIsBlockFmt(dtypeMmaA); // Are the block scales constant? bool const useConstSfA = useBlockScalingA && !tg::dtypeIsBlockFmt(dtypeA); + // Number elements per scaling factor. + int32_t const numEltsPerSf = useBlockScalingA ? tg::dtypeNumEltsPerSf(dtypeMmaA) : -1; + // TMEM cols group size in the K dimension. + int32_t kGroupSize = 4; + // Number of columns per stage. + int32_t const numColsPerStage = + useBlockScalingA ? ((tileK / (kGroupSize * numEltsPerSf)) * + tg::getTmemColStridePerGroup(tileM, mmaK, kGroupSize)) + : 0; // Number of columns for scaling factors of A. auto const numTmemColsSfA = - useConstSfA - ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK), 4) - : (useBlockScalingA ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * - (mFuseUtccpWithUtcmma ? 1 : numStages) - : 0); + useConstSfA ? tg::roundUp(numColsPerStage, 4) + : (numColsPerStage * (mFuseUtccpWithUtcmma ? 1 : numStages)); // Number of columns for Sf alignment. auto const numColsAlignmentSfA = 4; // No need to reuse TMEM. @@ -499,13 +506,19 @@ class KernelTraits { bool const useBlockScalingB = tg::dtypeIsBlockFmt(dtypeMmaB); // Are the block scales constant? bool const useConstSfB = useBlockScalingB && !tg::dtypeIsBlockFmt(dtypeB); + // Number elements per scaling factor. + int32_t const numEltsPerSf = useBlockScalingB ? tg::dtypeNumEltsPerSf(dtypeMmaB) : -1; + // TMEM cols group size in the K dimension. + int32_t kGroupSize = 4; + // Number of columns per stage. + int32_t const numColsPerStage = + useBlockScalingB ? ((tileK / (kGroupSize * numEltsPerSf)) * + tg::getTmemColStridePerGroup(tileN, mmaK, kGroupSize)) + : 0; // Number of columns for scaling factors of B. auto const numTmemColsSfB = - useConstSfB - ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK), 4) - : (useBlockScalingB ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * - (mFuseUtccpWithUtcmma ? 1 : numStages) - : 0); + useConstSfB ? tg::roundUp(numColsPerStage, 4) + : (numColsPerStage * (mFuseUtccpWithUtcmma ? 1 : numStages)); // Number of columns for Sf alignment. auto const numColsAlignmentSfB = 4; // No need to reuse TMEM. @@ -531,6 +544,8 @@ class KernelTraits { bool mFuseUtccpWithUtcmma; // Whether use the max TMEM overlap trick. bool mUseMaxTmemOverlap; + // The number of epilogue warps. + int32_t mNumEpilogueWarps; // Helper for SMEM allocation. MemAllocatorHelper mSmemAllocatorHelper; // Helper for TMEM allocation. diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h index fa250f8fe9..9fff10deff 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h @@ -37,11 +37,10 @@ namespace tg = trtllm::gen; #ifdef TLLM_ENABLE_CUDA -inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, - std::vector const& shapes, +inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, std::vector const& shapes, std::vector const& strides, std::vector const& tileShapes, void* gmemAddr, - bool doSwizzle = true) { + bool doPad, bool doSwizzle = true) { // The multiplication factor of the data padding in SMEM. int32_t padMultiplier = 1; CUtensorMap desc{}; @@ -56,12 +55,10 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, } else if (dtype == tg::Dtype::E2m1) { tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else if (dtype == tg::Dtype::MxE2m1 || dtype == tg::Dtype::MxInt4) { - if (mmaKind == tg::MmaKind::MxFp8Fp6Fp4) { + if (doPad) { padMultiplier = 2; tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; } else { - // Note: this is used with the MMA kind MxFp4NvFp4 and also when casting to a higher-precision - // type such as Bfloat16 before the MMA. tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } } else if (dtype == tg::Dtype::Fp32) { diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h index c8de154396..ba3275ec85 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h @@ -95,11 +95,13 @@ inline std::string mmaKindToString(MmaKind mmaKind) { //////////////////////////////////////////////////////////////////////////////////////////////////// -// function to get the TMEM column stride per group (i.e., 64 K elements) -inline int32_t getTmemColStridePerGroup(int32_t tileMn, int32_t mmaK) { - // Calculate the stride of TMEM column for every 64 elements in the K dimension - int32_t div = 2 * ceilDiv(tileMn, 64); - return mmaK == 96 ? std::max(4, div) : div; +// Get the TMEM column stride per group (i.e. kGroupSize * blockSize K elements) +inline int32_t getTmemColStridePerGroup(int32_t tileMn, int32_t mmaK, int32_t kGroupSize) { + int32_t colStride = 2 * ceilDiv(tileMn, 64); + if (mmaK == 96) { + colStride = std::max(4, colStride); + } + return colStride; } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index c209e5c509..01fdd235da 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -2287,6 +2287,7 @@ def run_moe_test( weight_processing, gated_act_type, cache_permute_indices, + zero_hidden_states=False, ): """Common test logic for all routing methods.""" skip_checks( @@ -2297,6 +2298,7 @@ def run_moe_test( num_tokens, hidden_size, intermediate_size, + zero_hidden_states=zero_hidden_states, ) torch.cuda.synchronize() @@ -2340,7 +2342,8 @@ def run_moe_test( else: routing_bias = None - hidden_states = 2 * torch.randn( + hidden_states_fn = torch.zeros if zero_hidden_states else torch.randn + hidden_states = 2 * hidden_states_fn( (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16 ) gemm1_weights = torch.randn( @@ -2471,6 +2474,13 @@ def run_moe_test( # Test: Renormalize routing +@pytest.mark.parametrize( + "zero_hidden_states", + [ + pytest.param(True, id="ZeroHiddenStates"), + pytest.param(False, id="RandomHiddenStates"), + ], +) @pytest.mark.parametrize("num_tokens", [8, 768, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [1024, 768, 512, 384]) @@ -2606,6 +2616,7 @@ def test_renormalize_routing( weight_processing, gated_act_type, cache_permute_indices, + zero_hidden_states, ): """Test Renormalize routing configurations.""" run_moe_test( @@ -2617,6 +2628,7 @@ def test_renormalize_routing( weight_processing, gated_act_type, cache_permute_indices, + zero_hidden_states=zero_hidden_states, ) diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 5a8f932117..19c01d5175 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -41,14 +41,20 @@ def skip_checks( num_tokens, hidden_size, intermediate_size, + zero_hidden_states=False, ): """Common skip logic for all tests.""" compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] not in [10]: pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") - # Check if moe_impl is FP4Moe by class name to avoid circular imports + # Check moe_impl class by name to avoid circular imports is_fp4_moe = type(moe_impl).__name__ == "FP4Moe" + is_fp8_block_scale_moe = type(moe_impl).__name__ == "FP8BlockScaleMoe" + + # Skip zero hidden states tests for non-FP8 Block Scale MoE implementations + if zero_hidden_states and not is_fp8_block_scale_moe: + pytest.skip("Skipping zero hidden states tests for non-FP8 Block Scale MoE.") # Skip incompatible combinations if gated_act_type == GatedActType.GeGlu and (