diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 2a991829dd..e7e40e772f 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -8,13 +8,109 @@ fp4_quantize, mxfp8_quantize, ) -from flashinfer.fused_moe import trtllm_fp4_block_scale_moe +from flashinfer.fused_moe import ( + trtllm_fp4_block_scale_moe, + trtllm_fp8_per_tensor_scale_moe, +) from flashinfer.autotuner import autotune from flashinfer.testing.utils import bench_gpu_time from flashinfer.utils import device_support_pdl +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +FLOAT4_E2M1_MAX = 6.0 + + +def fp8_quantize(x): + max = x.float().abs().nan_to_num().max() + scale = FLOAT8_E4M3_MAX / max + x = (x * scale).to(torch.float8_e4m3fn) + return x, 1.0 / scale -def bench_trtllm_gen_fused_moe_autotuner( + +def bench_trtllm_gen_fused_moe_autotuner_fp8( + tune_max_num_tokens: Optional[int], + quant_mode: Literal["Fp8-Per-Tensor"], + num_tokens: int, + num_experts: int, + hidden_size: int, + intermediate_size: int, + top_k: int, + warmups: int, + iterations: int, +): + device = torch.device("cuda:0") + enable_pdl = device_support_pdl(device) + routing_logits = torch.rand(num_tokens, num_experts, device=device).to( + torch.bfloat16 + ) + hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( + torch.bfloat16 + ) + w13 = torch.randn( + num_experts, intermediate_size * 2, hidden_size, device=device + ).to(torch.bfloat16) + w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to( + torch.bfloat16 + ) + + hidden_states, hidden_states_scale = fp8_quantize(hidden_states) + w13, w13_scale = fp8_quantize(w13) + w2, w2_scale = fp8_quantize(w2) + + output1_scale_scalar = torch.tensor( + [hidden_states_scale * w13_scale] * num_experts, device=device + ) + output1_scales_gate_scalar = torch.ones( + num_experts, device=device, dtype=torch.float32 + ) + output2_scale_scalar = torch.tensor( + [hidden_states_scale * w2_scale] * num_experts, device=device + ) + + fn = lambda: trtllm_fp8_per_tensor_scale_moe( + routing_logits, + None, # routing_bias + hidden_states, + w13, + output1_scale_scalar, + output1_scales_gate_scalar, + w2, + output2_scale_scalar, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + 1.0, # routed_scaling_factor + False, # use_routing_scales_on_input + None, + RoutingMethodType.TopK.value, + enable_pdl, + num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + ) + + def bench(do_autotune): + with autotune(do_autotune): + fn() + ms_list = bench_gpu_time( + fn, + dry_run_iters=warmups, + repeat_iters=iterations, + ) + median_ms = np.median(ms_list) + return median_ms + + ms = bench(do_autotune=False) + ms_tuned = bench(do_autotune=True) + print( + f"num tokens: {num_tokens}, num experts: {num_experts}, hidden size: {hidden_size}, intermediate size: {intermediate_size}, top k: {top_k}" + ) + print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms") + + +def bench_trtllm_gen_fused_moe_autotuner_fp4( tune_max_num_tokens: Optional[int], quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], num_tokens: int, @@ -143,12 +239,11 @@ def bench_trtllm_gen_fused_moe_autotuner( ) def bench(do_autotune): - # warmup with autotune(do_autotune): - for _ in range(warmups): - fn() + fn() ms_list = bench_gpu_time( fn, + dry_run_iters=warmups, 
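For context, a minimal Python sketch of the per-tensor FP8 recipe that the `fp8_quantize` helper above implements, extended with the dequantization step the MoE kernel applies through the per-expert output scale scalars (assumes a torch build with `float8_e4m3fn`; `fp8_per_tensor_roundtrip` is an illustrative name, not part of the benchmark):

```python
import torch

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def fp8_per_tensor_roundtrip(x: torch.Tensor):
    # One global scale maps the tensor's absolute max onto the FP8 E4M3 range.
    amax = x.float().abs().nan_to_num().max()
    scale = FLOAT8_E4M3_MAX / amax
    x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)
    # The kernel consumes the dequant scale (1 / scale); the benchmark folds the
    # activation and weight dequant scales into the per-expert output scalars.
    x_dq = x_fp8.float() * (1.0 / scale)
    return x_fp8, 1.0 / scale, x_dq
```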
repeat_iters=iterations, ) median_ms = np.median(ms_list) @@ -168,7 +263,7 @@ def bench(do_autotune): "--quant-mode", type=str, default="MxFP4xMxFP8", - choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], + choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16", "Fp8-Per-Tensor"], help="Quantization mode", ) parser.add_argument("--num-tokens", type=int, default=512, help="Number of tokens") @@ -193,14 +288,27 @@ def bench(do_autotune): "--iterations", type=int, default=100, help="Number of benchmark iterations" ) args = parser.parse_args() - bench_trtllm_gen_fused_moe_autotuner( - args.tune_max_num_tokens, - args.quant_mode, - args.num_tokens, - args.num_experts, - args.hidden_size, - args.intermediate_size, - args.top_k, - args.warmups, - args.iterations, - ) + if args.quant_mode == "Fp8-Per-Tensor": + bench_trtllm_gen_fused_moe_autotuner_fp8( + args.tune_max_num_tokens, + args.quant_mode, + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.warmups, + args.iterations, + ) + else: + bench_trtllm_gen_fused_moe_autotuner_fp4( + args.tune_max_num_tokens, + args.quant_mode, + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.warmups, + args.iterations, + ) diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index bf57fd5b9e..42fe8f7f59 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -144,6 +144,10 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes( gemmData.mProblemDimensions.mWorldSize = 1; gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + auto bmm = BatchedGemmInterface(); auto const configs = bmm.getBatchedGemmConfigs(); @@ -239,6 +243,10 @@ void TrtllmGenBatchedGemmRunner::run( int32_t multiProcessorCount; cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device); + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere bmm.runInitBeforeWorldSync(config, gemmData, static_cast(stream)); @@ -327,6 +335,10 @@ std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( gemmData.mProblemDimensions.mWorldSize = 1; gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) { auto const& optionsA = configs[idx0].mOptions; auto const& optionsB = configs[idx1].mOptions; @@ -387,8 +399,7 @@ std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( // Filter out invalid configs. 
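The runner hunks above default the new valid extents to the full problem size; a rough sketch of that convention, with hypothetical names (padding, when it happens, keeps M/N/K at the padded extent while ValidM/N/K describe the meaningful range):

```python
def problem_dims(m, n, k, valid_m=None, valid_n=None, valid_k=None):
    """Padded GEMM extents plus the valid ranges; with no padding the valid
    range defaults to the full M/N/K, as in the runner changes above."""
    return {
        "m": m, "valid_m": m if valid_m is None else valid_m,
        "n": n, "valid_n": n if valid_n is None else valid_n,
        "k": k, "valid_k": k if valid_k is None else valid_k,
    }
```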
std::vector validConfigIndices; for (auto const& configIndex : prioritizedIndices) { - auto const& config = configs[configIndex]; - auto isValidConfig = bmm.isValidConfig(config, gemmData); + auto isValidConfig = bmm.isValidConfig(configs[configIndex], gemmData); if (isValidConfig) { validConfigIndices.push_back(configIndex); } @@ -435,7 +446,9 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t auto const& config = configs[configIndex]; - return bmm.isValidConfig(config, gemmData); + // FIXME: temporarily disable split-k as renormalize routing plus expert number 256 failed in + // trtllm-gen ac83afb + return bmm.isValidConfig(config, gemmData) && config.mOptions.mClusterDimZ == 1; } } // namespace kernels diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 538dc92725..3fd9dab35e 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -63,13 +63,22 @@ std::set computeSelectedTileN(std::vector const& supported_til int64_t const num_tokens, int64_t const top_k, int64_t const num_local_experts) { float const avg_tokens_per_expert = static_cast(num_tokens * top_k) / num_local_experts; + // assume supported_tile_nums is sorted int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), supported_tile_nums.front(), supported_tile_nums.back()); - - std::set selected_tile_nums = { - std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim, - std::min(supported_tile_nums.back(), tile_tokens_dim * 2), - std::min(supported_tile_nums.back(), tile_tokens_dim * 4)}; + auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim); + + std::set selected_tile_nums; + selected_tile_nums.insert(tile_tokens_dim); + if (std::next(it) != supported_tile_nums.end()) { + selected_tile_nums.insert(*std::next(it)); + if (std::next(std::next(it)) != supported_tile_nums.end()) { + selected_tile_nums.insert(*std::next(std::next(it))); + } + } + if (it != supported_tile_nums.begin()) { + selected_tile_nums.insert(*std::prev(it)); + } return selected_tile_nums; } @@ -369,7 +378,7 @@ void trtllm_fp8_per_tensor_scale_moe( auto const hidden_size = hidden_states.size(1); bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8 - std::vector mSupportedTileN = {8, 16, 32, 64, 128}; + std::vector mSupportedTileN = {8, 16, 32, 64, 128, 192, 256}; std::set selected_tile_nums = computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); @@ -718,7 +727,7 @@ void trtllm_fp8_block_scale_moe( auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); - std::vector mSupportedTileN = {8, 16, 32, 64}; + std::vector mSupportedTileN = {8, 16, 32, 64, 128}; std::set selected_tile_nums = computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); @@ -1228,6 +1237,11 @@ Array trtllm_fp4_block_scale_moe( if (mDtypeAct != btg::Dtype::Bfloat16) { mSupportedTileN.push_back(128); } + if ((mDtypeAct == btg::Dtype::MxE4m3 && mDtypeWeights == btg::Dtype::MxE2m1) || + (mDtypeAct == btg::Dtype::E2m1 && mDtypeWeights == btg::Dtype::E2m1)) { + // MxFP4 x MxFP4 or NvFP4 x NvFP4 + mSupportedTileN.push_back(256); + } std::set selected_tile_nums = computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); // Build runners for all supported tile sizes @@ -1305,8 +1319,20 @@ Array> trtllm_get_valid_moe_configs( bool is_fp8_per_tensor = dtype_weights == 
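A rough Python sketch of the neighbour-based tile selection that `computeSelectedTileN` now performs, assuming `nextPowerOfTwo` rounds up and that `supported` is sorted ascending and contains every power of two in its range:

```python
import math

def select_tile_candidates(supported, num_tokens, top_k, num_local_experts):
    """Candidate tile sizes for autotuning: the entry closest to the average
    tokens per expert, its smaller neighbour, and up to two larger neighbours."""
    avg = max(num_tokens * top_k / num_local_experts, 1.0)
    tile = 1 << math.ceil(math.log2(avg))            # nextPowerOfTwo
    tile = min(max(tile, supported[0]), supported[-1])
    i = supported.index(tile)
    picks = set(supported[max(i - 1, 0):i + 3])      # prev, current, next two
    return sorted(picks)

# e.g. select_tile_candidates([8, 16, 32, 64, 128, 192, 256], 512, 8, 32)
# -> [64, 128, 192, 256]
```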
btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8; - if (is_fp4_without_bf16_act || is_fp8_per_tensor) { + if (useDeepSeekFp8) { + supported_tile_nums.push_back(128); + } else if (is_fp8_per_tensor) { supported_tile_nums.push_back(128); + supported_tile_nums.push_back(192); + supported_tile_nums.push_back(256); + } else if (is_fp4_without_bf16_act) { + supported_tile_nums.push_back(128); + } + + if ((dtype_act == btg::Dtype::MxE4m3 && dtype_weights == btg::Dtype::MxE2m1) || + (dtype_act == btg::Dtype::E2m1 && dtype_weights == btg::Dtype::E2m1)) { + // MxFP4 x MxFP4 or NvFP4 x NvFP4 + supported_tile_nums.push_back(256); } std::set selected_tile_nums = computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 527924559d..7f9a664291 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -392,7 +392,14 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) // Compute the runtime config for projections // Whether or not an expert is local is taken into account when smemExpertCount is computed // so we do not need to take it into account here. - const int32_t numCta = divUpLog2(count, params.mPaddingLog2); + + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(count, params.mPaddingLog2); + } else { + numCta = divUpTileN(count, params.mTileTokensDim); + } + int32_t ctaOffset; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); @@ -401,14 +408,31 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = - min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffset, params.mPaddingLog2) + count); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } // get the padded offset associated with this expert - const int32_t offset = mulLog2(ctaOffset, params.mPaddingLog2); - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t offset; + if constexpr (KernelParams::isPow2) { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } // write out padded count if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) { @@ -542,8 +566,6 @@ void runImpl(Data& data, void* stream) { } FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); - FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", - data.mPaddingLog2); int const numBlocks = data.mNumTokens; int const 
numThreadsHist = getMaxNumExperts(data.mNumExperts); diff --git a/csrc/trtllm_fused_moe_routing_llama4.cu b/csrc/trtllm_fused_moe_routing_llama4.cu index ebdd0b8720..13ca041644 100644 --- a/csrc/trtllm_fused_moe_routing_llama4.cu +++ b/csrc/trtllm_fused_moe_routing_llama4.cu @@ -189,7 +189,13 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #pragma unroll for (int ii = 0; ii < ExpertsPerThread; ++ii) { auto count = getBits(expertCount, ii); - numCta += divUpLog2(count, params.mPaddingLog2); + int32_t num; + if constexpr (KernelParams::isPow2) { + num = divUpLog2(count, params.mPaddingLog2); + } else { + num = divUpTileN(count, params.mTileTokensDim); + } + numCta += num; } // second, we perform the exclusive sum across the warp int32_t ctaOffset; @@ -202,22 +208,39 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #pragma unroll for (int ii = 0; ii < ExpertsPerThread; ++ii) { auto count = getBits(expertCount, ii); - auto finalNumCta = divUpLog2(count, params.mPaddingLog2); + int32_t finalNumCta; + if constexpr (KernelParams::isPow2) { + finalNumCta = divUpLog2(count, params.mPaddingLog2); + } else { + finalNumCta = divUpTileN(count, params.mTileTokensDim); + } auto expertIdx = threadIdx.x * ExpertsPerThread + ii; // during the scan for expert offsets, we can already write out // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit` for (int cta = 0; cta < finalNumCta; ++cta) { params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] = - min(mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffsetExp, params.mPaddingLog2) + count); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffsetExp, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffsetExp + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffsetExp, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] = min(mnLimit1, mnLimit2); } ctaOffsetExp += finalNumCta; } // at this point, we can write out padded count from the warp-aggregate if (cute::elect_one_sync()) { - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } @@ -236,12 +259,20 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam // of registers auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; int32_t finalExpertOffset[ExpertsPerThread]; - finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2); + if constexpr (KernelParams::isPow2) { + finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + finalExpertOffset[0] = mulTileN(ctaOffset, params.mTileTokensDim); + } #pragma unroll for (int ii = 1; ii < ExpertsPerThread; ++ii) { - finalExpertOffset[ii] = - finalExpertOffset[ii - 1] + - divUpMulLog2(getBits(expertCount, ii - 1), params.mPaddingLog2); + int32_t tmp; + if constexpr (KernelParams::isPow2) { + tmp = divUpMulLog2(getBits(expertCount, ii - 1), params.mPaddingLog2); + } else { + tmp = 
divUpMulTileN(getBits(expertCount, ii - 1), params.mTileTokensDim); + } + finalExpertOffset[ii] = finalExpertOffset[ii - 1] + tmp; } #pragma unroll @@ -455,8 +486,6 @@ void runImpl(Data const& data, void* stream) { NumExpertsLimit); FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); - FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", - data.mPaddingLog2); bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens) || diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 1a4823d481..56939f8d02 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -165,14 +165,24 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } __syncthreads(); // Get the number of CTAs and the offset for each CTA - const int32_t numCta = divUpLog2(accExpertCount, params.mPaddingLog2); + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(accExpertCount, params.mPaddingLog2); + } else { + numCta = divUpTileN(accExpertCount, params.mTileTokensDim); + } int32_t ctaOffset = 0; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); int32_t expertScanCounts = 0; - Scan(tempStorage) - .ExclusiveSum(divUpMulLog2(accExpertCount, params.mPaddingLog2), expertScanCounts); + int32_t tmpCount; + if constexpr (KernelParams::isPow2) { + tmpCount = divUpMulLog2(accExpertCount, params.mPaddingLog2); + } else { + tmpCount = divUpMulTileN(accExpertCount, params.mTileTokensDim); + } + Scan(tempStorage).ExclusiveSum(tmpCount, expertScanCounts); __syncthreads(); if (isLocalExpert) { @@ -180,15 +190,27 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) const int32_t localExpertIdx = (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = - min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffset, params.mPaddingLog2) + accExpertCount); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + accExpertCount; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + accExpertCount; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } } // at this point, we can write out padded count if (threadIdx.x == 0) { - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } @@ -399,8 +421,6 @@ void run(Data const& data, void* stream) { << NumExpertsLimit << "."; TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0) << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4."; - TVM_FFI_ICHECK_LE(data.mPaddingLog2, 8) - << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2; bool const useSingleBlock = data.mNumTokens <= 
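A compact Python model of the padding bookkeeping the routing kernels now support for arbitrary `mTileTokensDim` (CTA counts via `divUpTileN`, offsets and limits via `mulTileN`); `expert_counts` stands for the already-computed per-local-expert token counts:

```python
def pad_routing_counts(expert_counts, tile_tokens_dim):
    """Per-expert CTA counts, CTA->expert map, MN limits and padded offsets for
    an arbitrary tile size, mirroring the divUpTileN/mulTileN code paths."""
    cta_to_batch, cta_to_mn_limit, expert_offsets = [], [], []
    cta_offset = 0
    for expert_idx, count in enumerate(expert_counts):
        expert_offsets.append(cta_offset * tile_tokens_dim)         # mulTileN
        num_cta = (count + tile_tokens_dim - 1) // tile_tokens_dim  # divUpTileN
        for cta in range(num_cta):
            cta_to_batch.append(expert_idx)
            # The last CTA of each expert is clamped to the real token count.
            cta_to_mn_limit.append(min((cta_offset + cta + 1) * tile_tokens_dim,
                                       cta_offset * tile_tokens_dim + count))
        cta_offset += num_cta
    permuted_idx_size = cta_offset * tile_tokens_dim                # all CTAs
    return cta_to_batch, cta_to_mn_limit, expert_offsets, permuted_idx_size
```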
BlockKernelMaxNumTokens; diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index a33843516e..21a2cad4b5 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -23,7 +23,6 @@ #include "flashinfer/trtllm/fused_moe/DevKernel.h" #include "flashinfer/trtllm/fused_moe/RoutingKernel.h" #include "flashinfer/trtllm/fused_moe/runner.h" -// #include namespace tensorrt_llm { namespace kernels { @@ -39,7 +38,9 @@ inline int32_t computeLog2(int32_t val, std::string const& name = "") { while (n >>= 1) { ++out; } - FLASHINFER_CHECK((1 << out) == val, "Expected ", name, " to be a power of 2, got ", val); + if ((1 << out) != val) { + out = -1; + } return out; } } // namespace @@ -90,6 +91,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumLimitedGroups = topkGroup; routingData.mTopK = topK; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; @@ -124,6 +126,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumExperts = numExperts; routingData.mTopK = topK; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; @@ -170,6 +173,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumExperts = numExperts; routingData.mTopK = topK; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 25f679968f..733b7aed24 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -89,7 +89,7 @@ class ArtifactPath: TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23" + "23daeee32b60bde7947ce1ee7a58d4ab701f134b/batched_gemm-0d28130-add42d1" ) TRTLLM_GEN_GEMM: str = ( "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" @@ -105,7 +105,7 @@ class MetaInfoHash: "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a" ) TRTLLM_GEN_BMM: str = ( - "4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152" + "6cfade1395f9648aba5dcf2c329114619e175c0f238882555178f98c8f5c1968" ) TRTLLM_GEN_GEMM: str = ( "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340" @@ -123,7 +123,7 @@ class CheckSumHash: "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f" ) TRTLLM_GEN_BMM: str = ( - "8df2aae8f3aa39d64d2c723e775640beb4ac602a6cbb02e497c2a7316e349934" + "46ccf0492e3ed10135c2861a4f4ef9bb45846610f9a9d2ccaf2d5bf01d2006fd" ) DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" TRTLLM_GEN_GEMM: str = ( diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index f8af220916..a82fabd8c0 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -483,7 +483,7 @@ def choose_one( except Exception as e: shapes = self._get_input_sizes(tensors) logger.warning( - f"[Autotuner]: Skipping 
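The relaxed `computeLog2` above no longer asserts on non-power-of-two tile sizes; a small sketch of the same contract, returning -1 so callers fall back to the `mTileTokensDim` path:

```python
def compute_log2(val):
    """floor(log2(val)) when val is a power of two, else -1: mPaddingLog2 is
    kept for the power-of-two fast path, while mTileTokensDim covers the rest."""
    out, n = 0, val
    while n > 1:
        n >>= 1
        out += 1
    return out if (1 << out) == val else -1
```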
tactic {r} {tac}, due to failure while profiling." + f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling: {e}" ) # Log stacktrace as debug to not spam log diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index c91878ca0e..3ea148c780 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1676,9 +1676,10 @@ def trtllm_fp8_per_tensor_scale_moe( local_num_experts: int, routed_scaling_factor: Optional[float], use_routing_scales_on_input: bool, - tile_tokens_dim: int = 8, + tile_tokens_dim: Optional[int] = None, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, ) -> torch.Tensor: """FP8 per tensor scale MoE operation. @@ -1700,9 +1701,10 @@ def trtllm_fp8_per_tensor_scale_moe( local_num_experts: Number of experts handled by this device routed_scaling_factor: Scaling factor for routing use_routing_scales_on_input: Whether to use routing scales on input - tile_tokens_dim: Tile dimension for tokens (default: 8) + tile_tokens_dim: Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. + tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] @@ -1733,6 +1735,7 @@ def trtllm_fp8_per_tensor_scale_moe( use_routing_scales_on_input, routing_method_type, enable_pdl, + tune_max_num_tokens, ) @@ -1758,6 +1761,7 @@ def trtllm_fp8_block_scale_moe( use_shuffled_weight: bool = False, weight_layout: int = 0, enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, ) -> torch.Tensor: """FP8 block scale MoE operation. @@ -1778,9 +1782,10 @@ def trtllm_fp8_block_scale_moe( local_expert_offset: Offset of local experts in global expert space local_num_experts: Number of experts handled by this device routed_scaling_factor: Scaling factor for routing - tile_tokens_dim: Tile dimension for tokens (default: 8) + tile_tokens_dim: Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. + tune_max_num_tokens(int): Maximum number of tokens for tuning. 
(default: 8192) Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ @@ -1815,6 +1820,7 @@ def trtllm_fp8_block_scale_moe( use_shuffled_weight, weight_layout, enable_pdl, + tune_max_num_tokens, ) diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 11398fabd9..78c19e98ac 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -233,11 +233,12 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: ], extra_cuda_cflags=[ "-DTLLM_GEN_EXPORT_INTERFACE", + "-DTLLM_GEN_EXPORT_FLASHINFER", "-DTLLM_ENABLE_CUDA", "-DENABLE_BF16", "-DENABLE_FP8", "-DENABLE_FP4", - f'-DTLLM_GEN_BMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"', + f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"', ] + nvcc_flags, extra_include_paths=[ diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 6564aefa35..7873d0de14 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -381,6 +381,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec: ], extra_cuda_cflags=[ "-DTLLM_GEN_EXPORT_INTERFACE", + "-DTLLM_GEN_EXPORT_FLASHINFER", "-DTLLM_ENABLE_CUDA", f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"', ] @@ -531,6 +532,7 @@ def gen_trtllm_low_latency_gemm_module() -> JitSpec: ], extra_cuda_cflags=[ "-DTLLM_GEN_EXPORT_INTERFACE", + "-DTLLM_GEN_EXPORT_FLASHINFER", "-DTLLM_ENABLE_CUDA", f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"', ] diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h index 27955d2bdc..919d6cb00d 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h @@ -31,7 +31,9 @@ enum class RouteImpl { // Use LDGSTS to do the routing Ldgsts = 1, // Use UTMALDG.GATHER4 to do the routing - Tma = 2 + Tma = 2, + // Use LDG+STS to do the routing + LdgPlusSts = 3 }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,6 +50,10 @@ inline bool doesRouteImplUseTma(RouteImpl mode) { return (mode == RouteImpl::Tma //////////////////////////////////////////////////////////////////////////////////////////////////// +inline bool doesRouteImplUseLdgPlusSts(RouteImpl mode) { return (mode == RouteImpl::LdgPlusSts); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace batchedGemm //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h index 6b1f910178..f93f20d28e 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h @@ -24,18 +24,12 @@ #include "trtllm/gen/CudaKernelLauncher.h" #ifdef TLLM_GEN_EXPORT_INTERFACE +#ifdef TLLM_GEN_EXPORT_FLASHINFER #include "flashinferMetaInfo.h" -#endif // TLLM_GEN_EXPORT_INTERFACE - -#ifdef TLLM_GEN_BMM_CUBIN_PATH -static const std::string tllm_gen_bmm_cubin_path = std::string(TLLM_GEN_BMM_CUBIN_PATH); #else -static_assert(false, "TLLM_GEN_BMM_CUBIN_PATH macro is not defined when compiling"); -#endif - -namespace 
flashinfer::trtllm_cubin_loader { -std::string getCubin(const std::string& kernelName, const std::string& sha256); -} +#include "KernelMetaInfo.h" +#endif // TLLM_GEN_EXPORT_FLASHINFER +#endif // TLLM_GEN_EXPORT_INTERFACE namespace batchedGemm { @@ -79,13 +73,18 @@ struct BatchedGemmData { // The M dimension. // It is the total number of tokens if A is the activation matrix. // It is the total number of output channels if A is the weight matrix. + // ValidM/N/K by default assumes to be full range of M/N/K respectively. If we pad M/N/K due to + // alignment of other constraints, then we can specify ValidM/N/K to indicate the valid range. int32_t mM{0}; + int32_t mValidM{0}; // The N dimension. // It is the total number of tokens if B is the activation matrix. // It is the total number of output channels if B is the weight matrix. int32_t mN{0}; + int32_t mValidN{0}; // The K dimension. It is the hidden dimension of the input matrices. int32_t mK{0}; + int32_t mValidK{0}; // The rank id of the current device in the multi-gpu space. int32_t mRank{0}; // The number of devices in tensor-parallel group. @@ -457,28 +456,187 @@ class BatchedGemmInterface { public: using ModuleCache = std::unordered_map>; - BatchedGemmInterface() {} + ////////////////////////////////////////////////////////////////////////////////////////////////// + + BatchedGemmInterface(bool const exportsCubin = false, int32_t const numRotations = 1) + : mExportsCubin(exportsCubin), mNumRotations(numRotations) {} + + ////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifndef TLLM_GEN_EXPORT_INTERFACE + // Generates and compiles the kernel using either nvcc or nvrtc. + BatchedGemmConfig generateAndCompileKernel(BatchedGemmConfig const& batchedGemmConfig) const; +#endif + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Provided config must be validated with isValidConfig before the call. - int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& options, - void* cudaStream, int32_t multiProcessorCount, bool usePdl = true, - std::optional> moduleCache = std::nullopt); + int32_t run(BatchedGemmConfig const& config, void* workspace, + BatchedGemmData const& batchedGemmData, void* cudaStream, + int32_t /*multiProcessorCount*/, bool usePdl = true, + std::optional> moduleCache = std::nullopt) { + // Might be used. + (void)usePdl; + (void)moduleCache; + // Get options from config and data. + auto options = getOptionsFromConfigAndData(config, batchedGemmData); + + bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; + bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 && + options.mDtypeB == tg::Dtype::E4m3; + + auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData); + float* dPtrRowMax{nullptr}; + uint32_t* dPtrRowMaxBars{nullptr}; + + // Set the completion barriers to 0 if needed. 
+ if (useDeepSeekFp8 && options.mFusedAct) { + dPtrRowMax = reinterpret_cast(alignPtr(reinterpret_cast(workspace), 1024)); + dPtrRowMaxBars = reinterpret_cast( + alignPtr(reinterpret_cast(dPtrRowMax) + workspaceSizes[0], 1024)); + auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1], + reinterpret_cast(cudaStream)); + if (err != cudaSuccess) { + return 1; + } + } + + auto [numCtaBatch, numCtaTile, numCtaInner] = + getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); + auto kernelParams = KernelParamsSetup::setKernelParams( + options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB, + batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA, + batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA, + batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias, + batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC, + batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit, + batchedGemmData.mInputBuffers.mPtrGatedActAlpha, + batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap, + dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, + batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, + batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, + batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch); + + // The size of the grid. + std::vector grid = batchM ? std::vector{numCtaBatch, numCtaTile, numCtaInner} + : std::vector{numCtaTile, numCtaBatch, numCtaInner}; + + BatchedGemmConfig batchedGemmConfig = config; +#ifndef TLLM_GEN_EXPORT_INTERFACE + // Generate and compile the kernel if data is not provided. + if (config.mData == nullptr) { + batchedGemmConfig = generateAndCompileKernel(batchedGemmConfig); + } + TLLM_CHECK_ERROR(batchedGemmConfig.mCudaRunner != nullptr, "CudaRunner is not set"); + batchedGemmConfig.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid, + /* cluster */ {}, + /* instanceId */ batchedGemmConfig.mInstanceIdx); + return 0; +#endif + + CUmodule cuModule; + CUfunction cuFunction; + + if (moduleCache.has_value()) { + ModuleCache& moduleCacheRef = moduleCache.value().get(); + + // Modules are associated with a specific context, so the context is included in the key + CUcontext ctx; + unsigned long long ctxId; + cuCtxGetCurrent(&ctx); + cuCtxGetId(ctx, &ctxId); + + // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a + // string in decimal representation. + std::string const ctxName = + std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); + std::string const funcName = std::string(batchedGemmConfig.mFunctionName); + auto const moduleKey = ctxName + funcName; + auto module = moduleCacheRef.find(moduleKey); + + // Use cache if module is found, otherwise load and insert into cache + if (module != moduleCacheRef.end()) { + cuFunction = std::get<1>(module->second); + } else { + gemm::loadCubinData(&cuModule, batchedGemmConfig); + cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName); + moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); + } + } else { + gemm::loadCubinData(&cuModule, batchedGemmConfig); + cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName); + } + + // Prepare the grid/block. 
+ dim3 block3{static_cast(batchedGemmConfig.mNumThreadsPerCTA), + static_cast(1), static_cast(1)}; + dim3 grid3{(grid.size() > 0 ? static_cast(grid[0]) : 1u), + (grid.size() > 1 ? static_cast(grid[1]) : 1u), + (grid.size() > 2 ? static_cast(grid[2]) : 1u)}; + // Prepare the cluster size. + dim3 cluster3{static_cast(options.mClusterDimX), + static_cast(options.mClusterDimY), + static_cast(options.mClusterDimZ)}; + + // Run the kernel. + auto result = trtllm::gen::launchKernel( + (void*)&kernelParams, cudaStream, batchedGemmConfig.mSharedMemSize, cuFunction, block3, + grid3, cluster3, + usePdl && (batchedGemmConfig.mOptions.mGridWaitForPrimaryEarlyExit | + batchedGemmConfig.mOptions.mGridWaitForPrimaryA | + batchedGemmConfig.mOptions.mGridWaitForPrimaryB)); + if (result != CUDA_SUCCESS) { + return -1; + } + // If a module cache has not been given, unload the module to avoid leaking + if (!moduleCache.has_value()) { + cuModuleUnload(cuModule); + } + return 0; + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Initializes the buffers before the world sync. Must be called before run. int32_t runInitBeforeWorldSync(BatchedGemmConfig const& /* config */, BatchedGemmData const& /* data */, void* /* cudaStream */) const { return 0; - }; + } - size_t getWorkspaceSizeInBytes(BatchedGemmConfig const& /* config */, - BatchedGemmData const& /* data */) const; + ////////////////////////////////////////////////////////////////////////////////////////////////// + + size_t getWorkspaceSizeInBytes(BatchedGemmConfig const& config, + BatchedGemmData const& data) const { + auto workspaceSizes = getWorkspaceSizesInBytes(config, data); + auto size = std::accumulate(workspaceSizes.begin(), workspaceSizes.end(), 0); + // Additional 1023 bytes to align the pointer to 1024 + return size > 0 ? size + 1023 : 0; + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Returns the list of all available cubin configurations - BatchedGemmConfig const* getBatchedGemmConfigs() const; + BatchedGemmConfig const* getBatchedGemmConfigs() const { +#ifdef TLLM_GEN_EXPORT_INTERFACE + return tensorrt_llm::kernels::tllmGenBatchedGemmList; +#else + return nullptr; +#endif + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Returns the number of available cubin configurations - size_t getNumBatchedGemmConfigs() const; + size_t getNumBatchedGemmConfigs() const { +#ifdef TLLM_GEN_EXPORT_INTERFACE + return tensorrt_llm::kernels::tllmGenBatchedGemmListLen; +#else + return 0; +#endif + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// // Returns the grid dimensions of the current kernel. std::tuple getGridDim( @@ -523,6 +681,8 @@ class BatchedGemmInterface { return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner); } + ////////////////////////////////////////////////////////////////////////////////////////////////// + // Returns the number of CTAs of the current kernel. int32_t getNumCtas(BatchedGemmOptions const& options, std::optional maxNumCtasInBatchDim = std::nullopt) const { @@ -530,278 +690,117 @@ class BatchedGemmInterface { return numCtasBatch * numCtasTile * numCtasInner; } - // Returns true if the configuration of the cubin can be executed for the given params. 
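A small sketch of how the workspace sizing above composes: each buffer is padded to 1024 bytes and an extra 1023 bytes of slack lets `alignPtr` align the base pointer at run time (Python, names illustrative):

```python
def align_up(offset, alignment=1024):
    """Round a byte offset up to a power-of-two alignment, like alignPtr."""
    assert alignment & (alignment - 1) == 0
    return (offset + alignment - 1) & ~(alignment - 1)

def total_workspace_bytes(buffer_sizes, alignment=1024):
    """Sum of the individually padded buffer sizes plus (alignment - 1) slack
    so the workspace base pointer can itself be aligned."""
    total = sum(align_up(s, alignment) for s in buffer_sizes)
    return total + alignment - 1 if total > 0 else 0
```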
- bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const; + ////////////////////////////////////////////////////////////////////////////////////////////////// // Creates GemmOptions from kernel and data. BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, - BatchedGemmData const& data) const; - - private: - // Aligns the pointer to the alignment - template - inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const; - - // Returns the size of the workspace buffers in bytes - std::vector getWorkspaceSizesInBytes(BatchedGemmConfig const& config, - BatchedGemmData const& data) const; - - // Returns the size padded to the alignment - size_t getSizePaddedToAlignment(size_t size, size_t alignment) const; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline Dtype* BatchedGemmInterface::alignPtr(Dtype* ptr, int64_t alignment) const { - assert((alignment & (alignment - 1)) == 0 && "Alignment must be a power of 2"); - return reinterpret_cast((reinterpret_cast(ptr) + alignment - 1) & - ~(alignment - 1)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -BatchedGemmConfig const* BatchedGemmInterface::getBatchedGemmConfigs() const { -#ifdef TLLM_GEN_EXPORT_INTERFACE - return tensorrt_llm::kernels::tllmGenBatchedGemmList; -#else - return nullptr; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -size_t BatchedGemmInterface::getNumBatchedGemmConfigs() const { -#ifdef TLLM_GEN_EXPORT_INTERFACE - return sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList) / - sizeof(tensorrt_llm::kernels::tllmGenBatchedGemmList[0]); -#else - return 0; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -BatchedGemmOptions BatchedGemmInterface::getOptionsFromConfigAndData( - BatchedGemmConfig const& config, BatchedGemmData const& data) const { - // Create options from config and data. - BatchedGemmOptions options; - options = config.mOptions; - options.mM = data.mProblemDimensions.mM; - options.mN = data.mProblemDimensions.mN; - options.mK = data.mProblemDimensions.mK; - options.mBatchedM = data.mProblemDimensions.mBatchedM; - options.mBatchedN = data.mProblemDimensions.mBatchedN; - options.mBatchMode = data.mProblemDimensions.mBatchM ? BatchedGemmOptions::BatchMode::BatchM - : BatchedGemmOptions::BatchMode::BatchN; - options.mNumBatches = data.mProblemDimensions.mNumBatches; - options.mNumTokens = data.mProblemDimensions.mNumTokens; - return options; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -bool BatchedGemmInterface::isValidConfig(BatchedGemmConfig const& config, - BatchedGemmData const& data) const { - // Get options from config and data. - auto options = getOptionsFromConfigAndData(config, data); - - // Is Blackwell? 
- bool isBlackwell = gemm::isSmVersionBlackwell(config.mSm); + BatchedGemmData const& data) const { + BatchedGemmOptions options; + options = config.mOptions; + options.mM = data.mProblemDimensions.mM; + options.mN = data.mProblemDimensions.mN; + options.mK = data.mProblemDimensions.mK; + options.mValidM = data.mProblemDimensions.mValidM; + options.mValidN = data.mProblemDimensions.mValidN; + options.mValidK = data.mProblemDimensions.mValidK; + options.mBatchedM = data.mProblemDimensions.mBatchedM; + options.mBatchedN = data.mProblemDimensions.mBatchedN; + options.mBatchMode = data.mProblemDimensions.mBatchM ? BatchedGemmOptions::BatchMode::BatchM + : BatchedGemmOptions::BatchMode::BatchN; + options.mNumBatches = data.mProblemDimensions.mNumBatches; + options.mNumTokens = data.mProblemDimensions.mNumTokens; + return options; + } - // Check options without modifications. - return checkAndUpdateBatchedGemmOptions(options, isBlackwell, - /* updateOptions */ false); -} + ////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Returns true if the configuration of the cubin can be executed for the given params. + bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const { + // Get options from config and data. + auto options = getOptionsFromConfigAndData(config, data); -size_t BatchedGemmInterface::getSizePaddedToAlignment(size_t size, size_t alignment) const { - assert((alignment & (alignment - 1)) == 0); - return (size + alignment - 1) & ~(alignment - 1); -} + // Is Blackwell? + bool isBlackwell = gemm::isSmVersionBlackwell(config.mSm); -//////////////////////////////////////////////////////////////////////////////////////////////////// + // Check options without modifications. + return checkAndUpdateBatchedGemmOptions(options, isBlackwell, + /* updateOptions */ false); + } -size_t BatchedGemmInterface::getWorkspaceSizeInBytes(BatchedGemmConfig const& config, - BatchedGemmData const& data) const { - auto workspaceSizes = getWorkspaceSizesInBytes(config, data); - auto size = std::accumulate(workspaceSizes.begin(), workspaceSizes.end(), 0); - // Additional 1023 bytes to align the pointer to 1024 - return size > 0 ? size + 1023 : 0; -} + ////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// + private: + ////////////////////////////////////////////////////////////////////////////////////////////////// -std::vector BatchedGemmInterface::getWorkspaceSizesInBytes( - BatchedGemmConfig const& config, BatchedGemmData const& data) const { - std::vector workspaceSizes; + template + inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const { + assert((alignment & (alignment - 1)) == 0 && "Alignment must be a power of 2"); + return reinterpret_cast((reinterpret_cast(ptr) + alignment - 1) & + ~(alignment - 1)); + } - // Get options from config and data. 
- auto options = getOptionsFromConfigAndData(config, data); + ////////////////////////////////////////////////////////////////////////////////////////////////// - if (options.mUseDeepSeekFp8 && options.mFusedAct) { - int32_t totalNumPaddedTokens = 0; - auto const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; - if (!options.mEnablesEarlyExit || options.mNumTokens == 0) { - for (int32_t bi = 0; bi < options.mNumBatches; ++bi) { - totalNumPaddedTokens += batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM) - : gemm::divUpMul(options.mBatchedN[bi], options.mTileN); + // Returns the size of the workspace buffers in bytes + std::vector getWorkspaceSizesInBytes(BatchedGemmConfig const& config, + BatchedGemmData const& data) const { + std::vector workspaceSizes; + + // Get options from config and data. + auto options = getOptionsFromConfigAndData(config, data); + + if (options.mUseDeepSeekFp8 && options.mFusedAct) { + int32_t totalNumPaddedTokens = 0; + auto const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; + if (!options.mEnablesEarlyExit || options.mNumTokens == 0) { + for (int32_t bi = 0; bi < options.mNumBatches; ++bi) { + totalNumPaddedTokens += batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM) + : gemm::divUpMul(options.mBatchedN[bi], options.mTileN); + } + } else { + // Get tile in token dim. + auto tileTokensDim = batchM ? options.mTileM : options.mTileN; + totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim; } - } else { - // Get tile in token dim. - auto tileTokensDim = batchM ? options.mTileM : options.mTileN; - totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim; - } - - // Get options from config. - auto& options = config.mOptions; - int const tokenTile = batchM ? options.mTileM : options.mTileN; + // Get options from config. + auto& options = config.mOptions; - auto const numTokens = totalNumPaddedTokens; - auto const intermediateDim = batchM ? options.mN : options.mM; - auto const intermediateTile = batchM ? options.mTileN : options.mTileM; + int const tokenTile = batchM ? options.mTileM : options.mTileN; - auto const numBytesRowMax = intermediateDim * totalNumPaddedTokens / 128 * sizeof(float); + auto const numTokens = totalNumPaddedTokens; + auto const intermediateDim = batchM ? options.mN : options.mM; + auto const intermediateTile = batchM ? options.mTileN : options.mTileM; - auto const numTilesToken = numTokens / tokenTile; - auto const numTilesInt = intermediateDim / intermediateTile; - auto const numBytesRowMaxBars = numTilesToken * numTilesInt / 2 * sizeof(uint32_t); - - // TODO: do we need to pad to 1024? - workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMax, 1024)); - workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMaxBars, 1024)); - } + auto const numBytesRowMax = intermediateDim * totalNumPaddedTokens / 128 * sizeof(float); - return workspaceSizes; -} + auto const numTilesToken = numTokens / tokenTile; + auto const numTilesInt = intermediateDim / intermediateTile; + auto const numBytesRowMaxBars = numTilesToken * numTilesInt / 2 * sizeof(uint32_t); -//////////////////////////////////////////////////////////////////////////////////////////////////// -int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace, - BatchedGemmData const& batchedGemmData, void* cudaStream, - int32_t /* multiProcessorCount */, bool usePdl, - std::optional> moduleCache) { - // Might be used. 
- (void)usePdl; - (void)moduleCache; - // Get options from config and data. - auto options = getOptionsFromConfigAndData(config, batchedGemmData); - - bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM; - bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 && - options.mDtypeB == tg::Dtype::E4m3; - - auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData); - float* dPtrRowMax{nullptr}; - uint32_t* dPtrRowMaxBars{nullptr}; - - // Set the completion barriers to 0 if needed. - if (useDeepSeekFp8 && options.mFusedAct) { - dPtrRowMax = reinterpret_cast(alignPtr(reinterpret_cast(workspace), 1024)); - dPtrRowMaxBars = reinterpret_cast( - alignPtr(reinterpret_cast(dPtrRowMax) + workspaceSizes[0], 1024)); - auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1], - reinterpret_cast(cudaStream)); - if (err != cudaSuccess) { - return 1; + // TODO: do we need to pad to 1024? + workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMax, 1024)); + workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMaxBars, 1024)); } - } - - auto [numCtaBatch, numCtaTile, numCtaInner] = - getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim); - auto kernelParams = KernelParamsSetup::setKernelParams( - options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB, - batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA, - batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA, - batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias, - batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC, - batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit, - batchedGemmData.mInputBuffers.mPtrGatedActAlpha, - batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap, - dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, - batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, - batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch); - - // The size of the grid. - std::vector grid = batchM ? std::vector{numCtaBatch, numCtaTile, numCtaInner} - : std::vector{numCtaTile, numCtaBatch, numCtaInner}; -#ifdef TLLM_GEN_EXPORT_INTERFACE - CUmodule cuModule; - CUfunction cuFunction; - - auto fiModuleLoadData = [&](CUmodule* module) { - const std::string sha256 = config.mHash ? config.mHash : ""; - std::string fname_cubin = config.mFunctionName; - if (!fname_cubin.empty()) { - fname_cubin[0] = static_cast(std::toupper(static_cast(fname_cubin[0]))); - } - fname_cubin = tllm_gen_bmm_cubin_path + "/" + fname_cubin + ".cubin"; - std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256); - cuModuleLoadData(&cuModule, cubin.c_str()); - }; - - if (moduleCache.has_value()) { - ModuleCache& moduleCacheRef = moduleCache.value().get(); - - // Modules are associated with a specific context, so the context is included in the key - CUcontext ctx; - unsigned long long ctxId; - cuCtxGetCurrent(&ctx); - cuCtxGetId(ctx, &ctxId); - - // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a - // string in decimal representation. 
- std::string const ctxName = - std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); - std::string const funcName = std::string(config.mFunctionName); - auto const moduleKey = ctxName + funcName; - auto module = moduleCacheRef.find(moduleKey); - - // Use cache if module is found, otherwise load and insert into cache - if (module != moduleCacheRef.end()) { - cuFunction = std::get<1>(module->second); - } else { - fiModuleLoadData(&cuModule); - cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); - moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); - } - } else { - fiModuleLoadData(&cuModule); - cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + return workspaceSizes; } - // Prepare the grid/block. - dim3 block3{static_cast(config.mNumThreadsPerCTA), static_cast(1), - static_cast(1)}; - dim3 grid3{(grid.size() > 0 ? static_cast(grid[0]) : 1u), - (grid.size() > 1 ? static_cast(grid[1]) : 1u), - (grid.size() > 2 ? static_cast(grid[2]) : 1u)}; - // Prepare the cluster size. - dim3 cluster3{static_cast(options.mClusterDimX), - static_cast(options.mClusterDimY), - static_cast(options.mClusterDimZ)}; - - // Run the kernel. - auto result = trtllm::gen::launchKernel( - (void*)&kernelParams, cudaStream, config.mSharedMemSize, cuFunction, block3, grid3, cluster3, - usePdl && (config.mOptions.mGridWaitForPrimaryEarlyExit | - config.mOptions.mGridWaitForPrimaryA | config.mOptions.mGridWaitForPrimaryB)); - if (result != CUDA_SUCCESS) { - return -1; - } - // If a module cache has not been given, unload the module to avoid leaking - if (!moduleCache.has_value()) { - cuModuleUnload(cuModule); + ////////////////////////////////////////////////////////////////////////////////////////////////// + + // Returns the size padded to the alignment + size_t getSizePaddedToAlignment(size_t size, size_t alignment) const { + assert((alignment & (alignment - 1)) == 0); + return (size + alignment - 1) & ~(alignment - 1); } -#else - config.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid); -#endif + ////////////////////////////////////////////////////////////////////////////////////////////////// - return 0; -} + private: + // Whether to export the cubin file. + bool mExportsCubin; + // The number of rotations. 
+ int32_t mNumRotations; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h index 07dcd30be4..6e53d00c17 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h @@ -55,6 +55,13 @@ namespace batchedGemm { +namespace trtllm { +namespace gen { +class CudaRunner; +class GenCfg; +} // namespace gen +} // namespace trtllm + namespace batchedGemm { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -80,42 +87,47 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, - int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, - bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, - bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, - gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, - int mmaM, int mmaN, bool mockAllReduce, int n, int numRegsCastAWarps, - int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp, - int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, - int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, + int epilogueTileN, bool fuseUtccpWithUtcmma, bool gridTriggerSecondaryA, + bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, + bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, + gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, + int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, + int numEpilogueWarps, int numRegsCastAWarps, int numRegsCopySfLdsSttm, + int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, + int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, - bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, - bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, - gemmGatedAct::ActType actType, bool clampBeforeAct, std::vector batchedM, - std::vector batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch, - int numTokens, RouteImpl routeImpl, std::optional routeSfsImpl, - bool gridWaitForPrimaryRouting, bool fusedAct, bool useTmaOobOpt) + bool useMaxTmemOverlap, bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, + bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, 
+ int validM, int validN, int validK, int worldSize, + // GemmGatedActOptions + gemmGatedAct::ActType actType, bool clampBeforeAct, + // BatchedGemmOptions + std::vector batchedM, std::vector batchedN, BatchMode batchMode, bool fusedAct, + bool gridWaitForPrimaryRouting, bool isStaticBatch, int numBatches, int numRegsPerThreadLoadB, + int numRegsPerThreadLoadSfB, int numTokens, int numWarpsLoadB, int numWarpsLoadSfB, + RouteImpl routeImpl, std::optional routeSfsImpl, bool useTmaOobOpt) : gemmGatedAct::GemmGatedActOptions( gemm::GemmOptions( allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, - epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA, - gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, - gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, - layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, - numRegsCopySfLdsSttm, numSlicesForSplitK, numSlicesForSliceK, numStages, - numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, - numStagesWorkId, outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, - sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler, - transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, - useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB, - useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps, - useUnrollLoop2xForMma, worldSize), + epilogueLdtmBits, epilogueTileM, epilogueTileN, fuseUtccpWithUtcmma, + gridTriggerSecondaryA, gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, + gridWaitForPrimaryA, gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, + k, kernelTraits, layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, + numEpilogueWarps, numRegsCastAWarps, numRegsCopySfLdsSttm, + numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, + numSlicesForSliceK, numStages, numStagesMma, numStagesMmaWithinWorkTile, + numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp, + sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK, splitK, + tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule, + useDeepSeekFp8, useHoistTryWaitForCustomMmaSchedule, useMaxTmemOverlap, + usePerTokenSfA, usePerTokenSfB, useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, + useTwoMmaWarps, useUnrollLoop2xForMma, validM, validN, validK, worldSize), actType, clampBeforeAct), mBatchedM(batchedM), mBatchedN(batchedN), @@ -124,10 +136,11 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting), mIsStaticBatch(isStaticBatch), mNumBatches(numBatches), - mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), - mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), - mNumRegsCastAWarps(numRegsCastAWarps), + mNumRegsPerThreadLoadB{numRegsPerThreadLoadB}, + mNumRegsPerThreadLoadSfB{numRegsPerThreadLoadSfB}, mNumTokens(numTokens), + mNumWarpsLoadB{numWarpsLoadB}, + mNumWarpsLoadSfB{numWarpsLoadSfB}, mRouteImpl(routeImpl), mRouteSfsImpl(routeSfsImpl), mUseTmaOobOpt(useTmaOobOpt) {} @@ -147,14 +160,16 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { bool mIsStaticBatch{true}; // Number of Gemm batches. 
int mNumBatches; - // Number of registers per thread for non-epilogue warps - int mNumRegsPerThreadNonEpilogueWarp{0}; - // Number of registers per thread for epilogue warps - int mNumRegsPerThreadEpilogueWarp{0}; - // Number of registers for the cast A warps. - int mNumRegsCastAWarps{0}; + // Number of registers per thread for load B + int mNumRegsPerThreadLoadB{0}; + // Number of registers per thread for load SfB + int mNumRegsPerThreadLoadSfB{0}; // Total number of tokens. int mNumTokens{32}; + // Number of warps for load B + int mNumWarpsLoadB{0}; + // Number of warps for load SfB + int mNumWarpsLoadSfB{0}; // Whether load the input tokens and do routing. RouteImpl mRouteImpl{RouteImpl::NoRoute}; // Routing logic for scaling factors. If not specified, mRouteImpl is used. @@ -167,8 +182,8 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { //////////////////////////////////////////////////////////////////////////////////////////////////// // Check if the options are valid or not. -bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackwell, - bool updateOptions = true) { +inline bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackwell, + bool updateOptions = true) { bool isValid = true; if (options.mUseTmaOobOpt && !options.mUseTwoTmaLoadWarps) { if (updateOptions) { @@ -222,19 +237,21 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw if (options.mUseDeepSeekFp8) { if (batchM) { // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8. - TLLM_CHECK_ERROR(options.mN % 128 == 0, - "GEMM-N must be a multiple of 128 when using DeepSeek Fp8. Found ", - options.mN); + TLLM_CHECK_ERROR( + options.mN % 128 == 0 && options.mValidN % 128 == 0, + "GEMM-N and validN must be a multiple of 128 when using DeepSeek Fp8. Found ", options.mN, + " and validN=", options.mValidN); } else { // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8. - TLLM_CHECK_ERROR(options.mM % 128 == 0, - "GEMM-N must be a multiple of 128 when using DeepSeek Fp8. Found ", - options.mN); + TLLM_CHECK_ERROR( + options.mM % 128 == 0 && options.mValidM % 128 == 0, + "GEMM-M and validM must be a multiple of 128 when using DeepSeek Fp8. Found ", options.mM, + " and validM=", options.mValidM); } // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8. - TLLM_CHECK_ERROR(options.mK % 128 == 0, - "GEMM-K must be a multiple of 128 when using DeepSeek Fp8. Found ", - options.mK); + TLLM_CHECK_ERROR(options.mK % 128 == 0 && options.mValidK % 128 == 0, + "GEMM-K and validK must be a multiple of 128 when using DeepSeek Fp8. 
Found ", + options.mK, " and validK=", options.mValidK); TLLM_CHECK_ERROR(options.mDtypeC != tg::Dtype::E2m1 && options.mDtypeA == tg::Dtype::E4m3 && options.mDtypeB == tg::Dtype::E4m3, @@ -243,8 +260,10 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw if (options.mRouteSfsImpl.has_value() && options.mRouteSfsImpl.value() != options.mRouteImpl) { TLLM_CHECK_ERROR( - options.mRouteSfsImpl.value() == RouteImpl::Ldgsts && options.mRouteImpl == RouteImpl::Tma, - "RouteSfsImpl must be equal to RouteImpl, or Ldgsts, when RouteImpl is Tma"); + (options.mRouteSfsImpl.value() == RouteImpl::Ldgsts || + options.mRouteSfsImpl.value() == RouteImpl::LdgPlusSts) && + options.mRouteImpl == RouteImpl::Tma, + "RouteSfsImpl must be equal to RouteImpl, or Ldgsts/LdgPlusSts, when RouteImpl is Tma"); } else if (!options.mRouteSfsImpl.has_value()) { if (updateOptions) { options.mRouteSfsImpl = options.mRouteImpl; @@ -253,6 +272,16 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw return false; } } + + TLLM_CHECK_ERROR(options.mRouteImpl != RouteImpl::LdgPlusSts, + "LdgPlusSts does not support routing the tokens"); + + if (options.mRouteSfsImpl.has_value() && options.mRouteSfsImpl.value() == RouteImpl::LdgPlusSts) { + TLLM_CHECK_ERROR(!batchM, "LdgPlusSts only supports batch N"); + TLLM_CHECK_ERROR(options.mTileK <= 512 && options.mTileK >= 128, + "LdgPlusSts only supports 128 <= tileK <= 512"); + } + if (batchM) { if (options.mDtypeA == tg::Dtype::MxE2m1 && options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) { TLLM_CHECK_ERROR(doesRouteImplUseNoRoute(options.mRouteImpl), @@ -326,6 +355,7 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw "2CTA BatchedGemm does not support routing along M dimension. To support it, " "change the input routing data layout to be padded to clusterDimX size."); } + return isValid; } @@ -336,19 +366,18 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw //////////////////////////////////////////////////////////////////////////////////////////////////// struct BatchedGemmConfig { - // When TRT-LLM Gen is exported to the other frameworks, the TLLM_GEN_EXPORT_INTERFACE must be - // defined. In this case, the cubins will be loaded from the provided data and function name. - // Otherwise, the kernel will be loaded from the CudaRunner. 
-#ifdef TLLM_GEN_EXPORT_INTERFACE uint8_t const* mData{nullptr}; - uint32_t const mSize{0}; - uint32_t const mSharedMemSize{0}; + uint32_t mSize{0}; + uint32_t mSharedMemSize{0}; char const* mFunctionName{nullptr}; - uint32_t const mNumThreadsPerCTA{0}; + uint32_t mNumThreadsPerCTA{0}; char const* mHash{nullptr}; -#else + + std::string mGenCfgJsonStr{""}; + char const* mExecPath{nullptr}; trtllm::gen::CudaRunner* mCudaRunner{nullptr}; -#endif + trtllm::gen::GenCfg* mGenCfg{nullptr}; + int32_t mInstanceIdx{0}; BatchedGemmOptions mOptions; gemm::SmVersion mSm{gemm::SmVersion::Sm100a}; @@ -356,27 +385,32 @@ struct BatchedGemmConfig { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline std::string dumpOptions(BatchedGemmOptions const& options) { +inline std::string dumpOptions(BatchedGemmOptions const& options, bool dumpRuntimeParams = true) { std::stringstream ss; - ss << gemmGatedAct::dumpOptions(options) << ", "; - ss << "mBatchedM={}," << std::endl; - ss << "mBatchedN={}," << std::endl; + ss << gemmGatedAct::dumpOptions(options, dumpRuntimeParams) << ", "; + if (dumpRuntimeParams) { + ss << "mBatchedM={}," << std::endl; + ss << "mBatchedN={}," << std::endl; + } ss << "mBatchMode=batchedGemm::BatchedGemmOptions::BatchMode(" << static_cast(options.mBatchMode) << ")," << std::endl; - ss << "mNumBatches=" << options.mNumBatches << "," << std::endl; + ss << "mFusedAct=" << options.mFusedAct << "," << std::endl; + ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl; ss << "mIsStaticBatch=" << options.mIsStaticBatch << "," << std::endl; - ss << "mNumTokens=" << options.mNumTokens << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mNumBatches=" << options.mNumBatches << "," << std::endl; + } + ss << "mNumRegsPerThreadLoadB=" << options.mNumRegsPerThreadLoadB << "," << std::endl; + ss << "mNumRegsPerThreadLoadSfB=" << options.mNumRegsPerThreadLoadSfB << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mNumTokens=" << options.mNumTokens << "," << std::endl; + } + ss << "mNumWarpsLoadB=" << options.mNumWarpsLoadB << "," << std::endl; + ss << "mNumWarpsLoadSfB=" << options.mNumWarpsLoadSfB << "," << std::endl; ss << "mRouteImpl=batchedGemm::RouteImpl(" << static_cast(options.mRouteImpl) << ")," << std::endl; ss << "mRouteSfsImpl={batchedGemm::RouteImpl(" << static_cast(options.mRouteSfsImpl.value()) << ")}," << std::endl; - ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl; - ss << "mFusedAct=" << options.mFusedAct << "," << std::endl; - ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," - << std::endl; - ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," - << std::endl; - ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl; return ss.str(); } diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h index 1086cd4fd5..559118916d 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h @@ -45,6 +45,13 @@ namespace batchedGemm { +namespace trtllm { +namespace gen { +class CudaRunner; +class GenCfg; +} // namespace gen +} // namespace trtllm + 
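// Note: trtllm::gen::CudaRunner and trtllm::gen::GenCfg are only referenced as pointer members
// (see GemmGatedActConfig below), so forward declarations are sufficient here and the full
// TRT-LLM Gen runner headers do not need to be included on the export path.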
namespace gemmGatedAct { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -130,10 +137,6 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& "Unsupported output hidden tile size"); } - if (options.mUseDeepSeekFp8) { - TLLM_CHECK_ERROR(hiddenSize % 256 == 0, "Output hidden size must be a multiple of 256"); - } - if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) { int const outHiddenSize = (options.mTransposeMmaOutput ? options.mM : options.mN) / 2; int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC); @@ -148,6 +151,21 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& return false; } + auto const validHiddenSize = options.mTransposeMmaOutput ? options.mValidM : options.mValidN; + if (options.mUseDeepSeekFp8) { + TLLM_CHECK_ERROR(hiddenSize % 256 == 0 && validHiddenSize % 256 == 0, "Hidden size (", + hiddenSize, ") and valid hidden size (", validHiddenSize, + ") must be a multiple of 256"); + } + + // + if (options.mUseShuffledMatrixA) { + auto const shuffleBlockSize = gemm::getShuffleBlockSize(options.mEpilogueTileM); + TLLM_CHECK_ERROR( + hiddenSize % (2 * shuffleBlockSize) == 0 && validHiddenSize % (2 * shuffleBlockSize) == 0, + "M/validM must be a multiple of 2 * shuffle block size (", 2 * shuffleBlockSize, + ") when useShuffledMatrixA"); + } if (options.mNumSlicesForSplitK > 1) { TLLM_CHECK_ERROR(doesSplitKUseDsmem(options.mSplitK), "Split-k GMEM and GemmGatedAct are not supported yet."); @@ -163,11 +181,11 @@ inline bool checkAndUpdateGemmGatedActOptions(gemmGatedAct::GemmGatedActOptions& //////////////////////////////////////////////////////////////////////////////////////////////////// -inline std::string dumpOptions(GemmGatedActOptions const& options) { +inline std::string dumpOptions(GemmGatedActOptions const& options, bool dumpRuntimeParams = true) { std::stringstream ss; - ss << gemm::dumpOptions(options) << ", "; - ss << "mActType=" << "gemmGatedAct::ActType(" << static_cast(options.mActType) << ")," - << std::endl; + ss << gemm::dumpOptions(options, dumpRuntimeParams) << ", "; + ss << "mActType=" + << "gemmGatedAct::ActType(" << static_cast(options.mActType) << ")," << std::endl; ss << "mClampBeforeAct=" << options.mClampBeforeAct << "" << std::endl; return ss.str(); } @@ -179,19 +197,18 @@ inline std::string dumpOptions(GemmGatedActOptions const& options) { //////////////////////////////////////////////////////////////////////////////////////////////////// struct GemmGatedActConfig { - // When TRT-LLM Gen is exported to the other frameworks, the TLLM_GEN_EXPORT_INTERFACE must be - // defined. In this case, the cubins will be loaded from the provided data and function name. - // Otherwise, the kernel will be loaded from the CudaRunner. 
-#ifdef TLLM_GEN_EXPORT_INTERFACE uint8_t const* mData{nullptr}; - uint32_t const mSize{0}; - uint32_t const mSharedMemSize{0}; + uint32_t mSize{0}; + uint32_t mSharedMemSize{0}; char const* mFunctionName{nullptr}; - uint32_t const mNumThreadsPerCTA{0}; + uint32_t mNumThreadsPerCTA{0}; char const* mHash{nullptr}; -#else + + std::string mGenCfgJsonStr{""}; + char const* mExecPath{nullptr}; trtllm::gen::CudaRunner* mCudaRunner{nullptr}; -#endif + trtllm::gen::GenCfg* mGenCfg{nullptr}; + int32_t mInstanceIdx{0}; GemmGatedActOptions mOptions{}; gemm::SmVersion mSm{gemm::SmVersion::Sm100a}; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h index fc3bd88101..af6432f7a0 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h @@ -30,7 +30,14 @@ #include "trtllm/gen/CudaRunner.h" #include "trtllm/gen/GenCtx.h" #else +#ifdef TLLM_GEN_EXPORT_FLASHINFER +#include +namespace flashinfer::trtllm_cubin_loader { +std::string getCubin(const std::string& kernelName, const std::string& sha256); +} +#endif // TLLM_GEN_EXPORT_FLASHINFER #include +namespace batchedGemm { template void printArgs(T arg) { @@ -72,7 +79,12 @@ void printArgs(T first, Args... args) { #endif // TLLM_GEN_EXPORT_INTERFACE -namespace batchedGemm { +namespace trtllm { +namespace gen { +class CudaRunner; +class GenCfg; +} // namespace gen +} // namespace trtllm namespace gemm { @@ -91,28 +103,29 @@ struct GemmOptions { #endif GemmOptions() = default; - GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, int clusterDimX, int clusterDimY, int clusterDimZ, CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, - int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, - bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, - bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, - bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, MatrixLayout layoutA, - MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, - bool mockAllReduce, int n, int numRegsCopySfLdsSttm, int numSlicesForSplitK, - int numSlicesForSliceK, int numStages, int numStagesMma, - int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, - bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, - tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, - int sfReshapeFactor, bool sliceK, SplitK splitK, int tileK, int tileM, int tileN, - TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, - bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, + int epilogueTileM, int epilogueTileN, bool fuseUtccpWithUtcmma, + bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, + bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, + bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, + MatrixLayout layoutA, MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, + int mmaM, int mmaN, bool mockAllReduce, int n, int numEpilogueWarps, + int numRegsCastAWarps, int numRegsCopySfLdsSttm, int 
numRegsPerThreadEpilogueWarp, + int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, + int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, + int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, + bool patchF2fp, std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, + tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int sfReshapeFactor, bool sliceK, + SplitK splitK, int tileK, int tileM, int tileN, TileScheduler tileScheduler, + bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8, + bool useHoistTryWaitForCustomMmaSchedule, bool useMaxTmemOverlap, bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, - bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, - int worldSize) + bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int validM, + int validN, int validK, int worldSize) : mAllReduceAlgo{allReduceAlgo}, mBiasType{biasType}, mBlockK(blockK), @@ -133,6 +146,7 @@ struct GemmOptions { mEpilogueLdtmBits{epilogueLdtmBits}, mEpilogueTileM{epilogueTileM}, mEpilogueTileN{epilogueTileN}, + mFuseUtccpWithUtcmma{fuseUtccpWithUtcmma}, mGridTriggerSecondaryA{gridTriggerSecondaryA}, mGridTriggerSecondaryB{gridTriggerSecondaryB}, mGridWaitForPrimaryEarlyExit{gridWaitForPrimaryEarlyExit}, @@ -151,7 +165,11 @@ struct GemmOptions { mMmaN{mmaN}, mMockAllReduce{mockAllReduce}, mN{n}, + mNumEpilogueWarps{numEpilogueWarps}, + mNumRegsCastAWarps(numRegsCastAWarps), mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm), + mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), + mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), mNumSlicesForSplitK{numSlicesForSplitK}, mNumSlicesForSliceK{numSlicesForSliceK}, mNumStages{numStages}, @@ -176,6 +194,7 @@ struct GemmOptions { mUseCustomMmaSchedule{useCustomMmaSchedule}, mUseDeepSeekFp8{useDeepSeekFp8}, mUseHoistTryWaitForCustomMmaSchedule{useHoistTryWaitForCustomMmaSchedule}, + mUseMaxTmemOverlap{useMaxTmemOverlap}, mUsePerTokenSfA{usePerTokenSfA}, mUsePerTokenSfB{usePerTokenSfB}, mUseShuffledMatrixA{useShuffledMatrixA}, @@ -183,6 +202,9 @@ struct GemmOptions { mUseTwoTmaLoadWarps{useTwoTmaLoadWarps}, mUseTwoMmaWarps{useTwoMmaWarps}, mUseUnrollLoop2xForMma{useUnrollLoop2xForMma}, + mValidM{validM}, + mValidN{validN}, + mValidK{validK}, mWorldSize{worldSize} {} // The all-reduce algorithm. @@ -233,6 +255,8 @@ struct GemmOptions { int mEpilogueTileM{128}; // Tile size for the epilogue in N dimension. int mEpilogueTileN{32}; + // Whether fuse UTCCP with UTC*MMA. + bool mFuseUtccpWithUtcmma{false}; // Whether load task A triggers the next grid. bool mGridTriggerSecondaryA{false}; // Whether load task B triggers the next grid. @@ -269,8 +293,16 @@ struct GemmOptions { bool mMockAllReduce{false}; // The N dimension of GEMM. int mN{64 * 4}; + // Number of Epilogue Warps + int mNumEpilogueWarps{4}; + // Number of registers for the cast A warps. + int mNumRegsCastAWarps{0}; // Number of registers for the LDS+STTM warps. int mNumRegsCopySfLdsSttm{0}; + // Number of registers per thread for epilogue warps + int mNumRegsPerThreadEpilogueWarp{0}; + // Number of registers per thread for non-epilogue warps + int mNumRegsPerThreadNonEpilogueWarp{0}; // Number of partitions along the K dimension. When mNumSlicesForSplitK > 1, // the problem is distributed across several SMs, where each CTA works on its local K slice. 
// Partial results are accumulated afterwards using either GMEM or DSMEM (in CGA) @@ -329,6 +361,8 @@ struct GemmOptions { // k-block. It benefits when the next k-block is already available and thus sustaining the // momentum, but it adds latency to the first k-block for smaller k-loop. bool mUseHoistTryWaitForCustomMmaSchedule{false}; + // Whether use the max Tmem overlap trick. + bool mUseMaxTmemOverlap{false}; // Apply per-token scales from A bool mUsePerTokenSfA{false}; // Apply per-token scales from B @@ -343,6 +377,15 @@ struct GemmOptions { bool mUseTwoMmaWarps{false}; // Whether to unroll the loop by 2x. bool mUseUnrollLoop2xForMma{true}; + // The valid range of M/N/K dimension of GEMM without padding values. + // Used to opportunistically remove memory traffic from the padding due to rigid SF shape + // constraint or TMA constraint. Such as: + // 1. outputDim % (4 * sfBlockSize) == 0; as 4x SFs are packed into 4 bytes + // 2. MxFp4 x Fp8 mmaType requires bespoke TMA load which requires hiddenDim % 128 == 0 + // 3. TMA requires 16B alignment for each row + int mValidM{-1}; + int mValidN{-1}; + int mValidK{-1}; // World size for all-reduce. int mWorldSize{1}; }; @@ -365,19 +408,17 @@ inline bool isSmVersionBlackwell(SmVersion smVersion) { //////////////////////////////////////////////////////////////////////////////////////////////////// struct GemmConfig { - // When TRT-LLM Gen is exported to the other frameworks, the TLLM_GEN_EXPORT_INTERFACE must be - // defined. In this case, the cubins will be loaded from the provided data and function name. - // Otherwise, the kernel will be loaded from the CudaRunner. -#ifdef TLLM_GEN_EXPORT_INTERFACE uint8_t const* mData{nullptr}; - uint32_t const mSize{0}; - uint32_t const mSharedMemSize{0}; + uint32_t mSize{0}; + uint32_t mSharedMemSize{0}; char const* mFunctionName{nullptr}; - uint32_t const mNumThreadsPerCTA{0}; + uint32_t mNumThreadsPerCTA{0}; char const* mHash{nullptr}; -#else + std::string mGenCfgJsonStr{""}; + char const* mExecPath{nullptr}; trtllm::gen::CudaRunner* mCudaRunner{nullptr}; -#endif + trtllm::gen::GenCfg* mGenCfg{nullptr}; + int32_t mInstanceIdx{0}; GemmOptions mOptions{}; SmVersion mSm{SmVersion::Sm100a}; @@ -407,7 +448,7 @@ inline std::string toString(trtllm::gen::MmaKind e) { //////////////////////////////////////////////////////////////////////////////////////////////////// -inline std::string dumpOptions(GemmOptions const& options) { +inline std::string dumpOptions(GemmOptions const& options, bool dumpRuntimeParams = true) { std::stringstream ss; ss << "mAllReduceAlgo=" << "gemm::AllReduceAlgo(" << static_cast(options.mAllReduceAlgo) << ")" @@ -447,6 +488,7 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mEpilogueLdtmBits=" << options.mEpilogueLdtmBits << "," << std::endl; ss << "mEpilogueTileM=" << options.mEpilogueTileM << "," << std::endl; ss << "mEpilogueTileN=" << options.mEpilogueTileN << "," << std::endl; + ss << "mFuseUtccpWithUtcmma=" << options.mFuseUtccpWithUtcmma << "," << std::endl; ss << "mGridTriggerSecondaryA=" << options.mGridTriggerSecondaryA << "," << std::endl; ss << "mGridTriggerSecondaryB=" << options.mGridTriggerSecondaryB << "," << std::endl; ss << "mGridWaitForPrimaryEarlyExit=" << options.mGridWaitForPrimaryEarlyExit << "," << std::endl; @@ -454,14 +496,18 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mGridWaitForPrimaryB=" << options.mGridWaitForPrimaryB << "," << std::endl; ss << "mHoistLoadTaskInit=" << options.mHoistLoadTaskInit << "," << 
std::endl; ss << "mHoistMmaTaskTryWaits=" << options.mHoistMmaTaskTryWaits << "," << std::endl; - ss << "mK=" << options.mK << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mK=" << options.mK << "," << std::endl; + } ss << "mKernelTraits={}" << "," << std::endl; ss << "mLayoutA=gemm::MatrixLayout(" << static_cast(options.mLayoutA) << ")" << "," << std::endl; ss << "mLayoutB=gemm::MatrixLayout(" << static_cast(options.mLayoutB) << ")" << "," << std::endl; - ss << "mM=" << options.mM << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mM=" << options.mM << "," << std::endl; + } ss << "mMmaK=" << options.mMmaK << "," << std::endl; ss << "mMmaKind=" << "trtllm::gen::MmaKind(" << static_cast(options.mMmaKind) << ")" @@ -469,8 +515,16 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mMmaM=" << options.mMmaM << "," << std::endl; ss << "mMmaN=" << options.mMmaN << "," << std::endl; ss << "mMockAllReduce=" << options.mMockAllReduce << "," << std::endl; - ss << "mN=" << options.mN << "," << std::endl; + if (dumpRuntimeParams) { + ss << "mN=" << options.mN << "," << std::endl; + } + ss << "mNumEpilogueWarps=" << options.mNumEpilogueWarps << "," << std::endl; + ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; ss << "mNumRegsCopySfLdsSttm=" << options.mNumRegsCopySfLdsSttm << "," << std::endl; + ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," + << std::endl; + ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," + << std::endl; ss << "mNumSlicesForSplitK=" << options.mNumSlicesForSplitK << "," << std::endl; ss << "mNumSlicesForSliceK=" << options.mNumSlicesForSliceK << "," << std::endl; ss << "mNumStages=" << options.mNumStages << "," << std::endl; @@ -512,6 +566,7 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mUseDeepSeekFp8=" << options.mUseDeepSeekFp8 << "," << std::endl; ss << "mUseHoistTryWaitForCustomMmaSchedule=" << options.mUseHoistTryWaitForCustomMmaSchedule << "," << std::endl; + ss << "mUseMaxTmemOverlap=" << options.mUseMaxTmemOverlap << "," << std::endl; ss << "mUsePerTokenSfA=" << options.mUsePerTokenSfA << "," << std::endl; ss << "mUsePerTokenSfB=" << options.mUsePerTokenSfB << "," << std::endl; ss << "mUseShuffledMatrixA=" << options.mUseShuffledMatrixA << "," << std::endl; @@ -519,7 +574,12 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mUseTwoTmaLoadWarps=" << options.mUseTwoTmaLoadWarps << "," << std::endl; ss << "mUseTwoMmaWarps=" << options.mUseTwoMmaWarps << "," << std::endl; ss << "mUseUnrollLoop2xForMma=" << options.mUseUnrollLoop2xForMma << "," << std::endl; - ss << "mWorldSize=" << options.mWorldSize << std::endl; + if (dumpRuntimeParams) { + ss << "mValidM=" << options.mValidM << "," << std::endl; + ss << "mValidN=" << options.mValidN << "," << std::endl; + ss << "mValidK=" << options.mValidK << "," << std::endl; + ss << "mWorldSize=" << options.mWorldSize << std::endl; + } return ss.str(); } @@ -578,6 +638,51 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } + // If validM/N/K is not specified, then assume the full range of the dimension is valid. + if (options.mValidM == -1) { + options.mValidM = options.mM; + } + if (options.mValidN == -1) { + options.mValidN = options.mN; + } + if (options.mValidK == -1) { + options.mValidK = options.mK; + } + + // It must not exceed the padded dimensions. 
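  // Illustrative note: mM/mN/mK may carry padding introduced to satisfy the SF-shape or TMA
  // alignment constraints documented with mValidM/mValidN/mValidK, while the valid sizes give
  // the unpadded extent that holds real data. Downstream, makeTmaShapeStrideAbc() in
  // KernelParams.h builds the TMA tensor shapes from the valid sizes and the strides from the
  // padded sizes, so the padded tail does not generate memory traffic. The check below clamps
  // valid sizes that exceed the padded dimensions (or rejects the options when updateOptions
  // is false).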
+ if (options.mValidM > options.mM || options.mValidN > options.mN || + options.mValidK > options.mK) { + TLLM_LOG_WARNING( + options.mValidK <= options.mK, + "ValidM, ValidN, and ValidK must be less than or equal to M, N, and K respectively."); + if (updateOptions) { + options.mValidM = std::min(options.mValidM, options.mM); + options.mValidN = std::min(options.mValidN, options.mN); + options.mValidK = std::min(options.mValidK, options.mK); + } else { + return false; + } + } + + // BlockMajorK layout does not support validM, validN, validK parameters + if (options.mLayoutA == gemm::MatrixLayout::BlockMajorK || + options.mLayoutB == gemm::MatrixLayout::BlockMajorK) { + bool hasValidParams = (options.mValidM != -1 && options.mValidM != options.mM) || + (options.mValidN != -1 && options.mValidN != options.mN) || + (options.mValidK != -1 && options.mValidK != options.mK); + TLLM_CHECK_ERROR(!hasValidParams, + "BlockMajorK layout does not support validM/validN/validK parameters due to " + "swizzled layout. " + "Found validM=", + options.mValidM, " validN=", options.mValidN, " validK=", options.mValidK); + } + +#ifdef TLLM_PUBLIC_RELEASE + if (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3) { + TLLM_CHECK_ERROR(false, "E2m1 x E4m3 is not supported for JIT compile. Use cubins instead."); + } +#endif // TLLM_PUBLIC_RELEASE + // Check that the A cast is supported. // Currently, we only support {MxFp4, NvFp4} -> Bf16. TLLM_CHECK_ERROR( @@ -607,6 +712,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mDtypeA == tg::Dtype::MxE2m1 && options.mDtypeMmaA == tg::Dtype::Bfloat16, "PatchF2fp is only supported for MxFp4 to Bf16 casts."); } +#ifdef TLLM_PUBLIC_RELEASE + options.mPatchF2fp = false; +#endif // TLLM_PUBLIC_RELEASE // FIXME: We do not support different dtypes for A and B when not on Blackwell. if (!isBlackwell) { @@ -819,7 +927,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in (padMultiplierB * tg::dtypeGetNumBits(options.mDtypeB) * options.mK / 8) % 16 == 0, "K dimension of B must be aligned to 16 bytes."); - if (options.mDtypeC == tg::Dtype::E2m1 || options.mDtypeC == tg::Dtype::MxE4m3) { + if (tg::dtypeIsBlockFmt(options.mDtypeC)) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); TLLM_CHECK_ERROR( @@ -836,6 +944,10 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in int const hiddenGranularity = 4 * tg::dtypeNumEltsPerSf(options.mDtypeC); TLLM_CHECK_ERROR(hiddenDim % hiddenGranularity == 0, "Hidden dim (", hiddenDim, ") must be a multiple of ", hiddenGranularity, " for block-scaled outputs."); + int const validHiddenDim = options.mTransposeMmaOutput ? 
options.mValidM : options.mValidN; + TLLM_CHECK_ERROR(validHiddenDim % tg::dtypeNumEltsPerSf(options.mDtypeC) == 0, + "Valid hidden dim (", validHiddenDim, ") must be a multiple of ", + tg::dtypeNumEltsPerSf(options.mDtypeC), " for block-scaled outputs."); TLLM_CHECK_ERROR(!options.mTransposeMmaOutput || options.mUseShuffledMatrixA, "Transposing block-scaled outputs requires shuffled A."); } @@ -901,8 +1013,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if (options.mUseShuffledMatrixA) { auto const shuffleBlockSize = getShuffleBlockSize(options.mEpilogueTileM); - TLLM_CHECK_ERROR(options.mM % shuffleBlockSize == 0, - "M must be a multiple of shuffle block size (", shuffleBlockSize, + TLLM_CHECK_ERROR(options.mM % shuffleBlockSize == 0 && options.mValidM % shuffleBlockSize == 0, + "M/validM must be a multiple of shuffle block size (", shuffleBlockSize, ") when useShuffledMatrixA"); } @@ -1084,9 +1196,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in // options.mUseTwoMmaWarps = true; // Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8. - TLLM_CHECK_ERROR(options.mK % 128 == 0, - "GEMM-K must be a multiple of 128 when using DeepSeek Fp8. Found ", - options.mK); + TLLM_CHECK_ERROR(options.mK % 128 == 0 && options.mValidK % 128 == 0, + "GEMM-K and validK must be a multiple of 128 when using DeepSeek Fp8. Found ", + options.mK, " and validK=", options.mValidK); // Check that the output tile N can be processed with the epilogue tile granularity. TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerEpilogueTile == 0, @@ -1100,6 +1212,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in hiddenDimPerMma, ")"); } + TLLM_CHECK_ERROR(options.mNumEpilogueWarps == 4 || options.mNumEpilogueWarps == 8, + "mNumEpilogueWarps has to be either 4 or 8."); + if (options.mSliceK) { TLLM_CHECK_ERROR(isBlackwell, "Slice-K is not supported on Hopper"); @@ -1253,7 +1368,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "At least one matrix must be in k-major layout"); // Some features are currently only support when both matrices are in K-major format - if (options.mLayoutB != MatrixLayout::MajorK || options.mLayoutB != MatrixLayout::MajorK) { + if (options.mLayoutA != MatrixLayout::MajorK || options.mLayoutB != MatrixLayout::MajorK) { TLLM_CHECK_ERROR(isBlackwell, "Non K-major layouts are only supported on Blackwell"); TLLM_CHECK_ERROR(options.mSplitK == SplitK::None, "Non K-major layouts do not support split K"); } @@ -1303,6 +1418,31 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "Bias is not supported for Meta Fp8"); } + if (options.mUseMaxTmemOverlap) { + TLLM_CHECK_ERROR(options.mUseTmaStore, "mUseMaxTmemOverlap only works with TMA store"); + TLLM_CHECK_ERROR(options.mFuseUtccpWithUtcmma, + "mUseMaxTmemOverlap only works with mFuseUtccpWithUtcmma"); + TLLM_CHECK_ERROR(options.mNumSlicesForSplitK == 1, + "mUseMaxTmemOverlap does not work with splitK"); + TLLM_CHECK_ERROR(options.mNumSlicesForSliceK == 1, + "mUseMaxTmemOverlap does not work with sliceK"); + TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, + "mUseMaxTmemOverlap does not work with mUseDeepSeekFp8"); + TLLM_CHECK_ERROR(!options.mUseUnrollLoop2xForMma, + "mUseMaxTmemOverlap does not work with mUseUnrollLoop2xForMma"); + } + + if (options.mNumEpilogueWarps > 4) { + TLLM_CHECK_ERROR(options.mUseTmaStore, + "Using more than 4 warps for epilogue 
only works with TMA store"); + TLLM_CHECK_ERROR(options.mNumSlicesForSplitK == 1, + "Using more than 4 warps for epilogue does not work with splitK"); + TLLM_CHECK_ERROR(options.mNumSlicesForSliceK == 1, + "Using more than 4 warps for epilogue does not work with sliceK"); + TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, + "Using more than 4 warps for epilogue does not work with mUseDeepSeekFp8"); + } + if (updateOptions) { // Init kernel traits. options.mKernelTraits = KernelTraits( @@ -1311,6 +1451,7 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mTileK, options.mEpilogueTileM, options.mEpilogueTileN, options.mNumStages, options.mNumStagesMma, options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, options.mSplitK, options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, + options.mFuseUtccpWithUtcmma, options.mUseMaxTmemOverlap, options.mNumEpilogueWarps, options.mTileScheduler == TileScheduler::Persistent, options.mUseDeepSeekFp8, options.mUsePerTokenSfA, options.mUsePerTokenSfB, /* useTwoCtas*/ options.mClusterDimX == 2, options.mBiasType); @@ -1321,6 +1462,59 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in //////////////////////////////////////////////////////////////////////////////////////////////////// +inline bool getDoesScaleC(tg::Dtype dtypeC) { + // Need to scale/quantize the output C matrix when the output type is Fp8 or NvFp4. + return dtypeC == tg::Dtype::E4m3 || dtypeC == tg::Dtype::E2m1; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool getDoesScaleAb(tg::Dtype dtypeA, tg::Dtype dtypeB, bool useDeepSeekFp8) { + // Need to scale/dequantize the input A/B matrices when the input type is Fp8 or NvFp4 and + // DeepSeekFp8 is not used. + bool const doesScaleAb{ + dtypeA == tg::Dtype::E2m1 || dtypeB == tg::Dtype::E2m1 || + ((dtypeA == tg::Dtype::E4m3 || dtypeB == tg::Dtype::E4m3) && !useDeepSeekFp8)}; + return doesScaleAb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline bool getKernelDoesScaleC(tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, + bool useDeepSeekFp8) { + // In the Gemm/BatchedGemm kernels, dequantScaleAb and quantScaleC are combined into one single + // scaling factor (called scaleC). As a result, we combine the logic for getDoesScaleAb and + // getDoesScaleC. + return getDoesScaleC(dtypeC) || getDoesScaleAb(dtypeA, dtypeB, useDeepSeekFp8); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline CUresult loadCubinData(CUmodule* module, Config const& config) { + // Trtllm links the cubin into the executable while Flashinfer loads the cubin from storage. +#ifdef TLLM_GEN_EXPORT_FLASHINFER +#ifdef TLLM_GEN_GEMM_CUBIN_PATH + static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH); + const std::string sha256 = config.mHash ? 
config.mHash : ""; + std::string fileName = config.mFunctionName; + if (!fileName.empty()) { + fileName[0] = static_cast(std::toupper(static_cast(fileName[0]))); + } + const std::string& data = flashinfer::trtllm_cubin_loader::getCubin( + tllm_gen_gemm_cubin_path + "/" + fileName + ".cubin", sha256); + CUresult result = cuModuleLoadData(module, data.c_str()); +#else + static_assert(false, "TLLM_GEN_GEMM_CUBIN_PATH macro is not defined when compiling"); +#endif // TLLM_GEN_GEMM_CUBIN_PATH +#else + CUresult result = cuModuleLoadData(module, config.mData); +#endif // TLLM_GEN_EXPORT_FLASHINFER + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace gemm #ifdef TLLM_GEN_EXPORT_INTERFACE diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h index 7e0474bb5f..800c8546ef 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h @@ -82,8 +82,18 @@ bool useTmaOobOptC(BatchedGemmOptions const& options) { // Create the TMA shape/stride for A/B/C. template -static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, int mK, int tileM, - int tileN, int tileK, MatrixType matrixType) { +static auto makeTmaShapeStrideAbc(GemmOptions const& options, int sizeM, int sizeN, int sizeK, + int tileM, int tileN, int tileK, MatrixType matrixType, + int validM = -1, int validN = -1, int validK = -1) { + if (validM == -1) { + validM = sizeM; + } + if (validN == -1) { + validN = sizeN; + } + if (validK == -1) { + validK = sizeK; + } // Weights matrix is A if we transpose the output of MMA (to have it M-major). // Otherwise, it is B, when the output of MMA is K-major. bool const isWeights = (matrixType == MatrixType::MatrixA && options.mTransposeMmaOutput) || @@ -96,9 +106,11 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in : matrixType == MatrixType::MatrixC ? useTmaOobOptC(options) : false; - // The outer dimension. + // The outer dimension. Uses padded dimensions for strides and valid dimensions for shapes. auto numTokens = - (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? mM : mN; + (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? sizeM : sizeN; + auto numTokensValid = + (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? validM : validN; // The outer dimension tile size. auto ctaTileNumTokens = (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? tileM : tileN; @@ -107,7 +119,8 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileM : ctaTileNumTokens; // The inner dimension. - auto hiddenSize = (matrixType == MatrixType::MatrixC) ? mN : mK; + auto hiddenSize = (matrixType == MatrixType::MatrixC) ? sizeN : sizeK; + auto hiddenSizeValid = (matrixType == MatrixType::MatrixC) ? validN : validK; // The inner dimension tile size. auto ctaTileHiddenSize = (matrixType == MatrixType::MatrixC) ? tileN : tileK; // The inner dimension of TMA box shape. @@ -117,6 +130,7 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // Swap matrix C sizes if output is transposed. 
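  // Note: both the padded extents (numTokens/hiddenSize) and the valid extents
  // (numTokensValid/hiddenSizeValid) are swapped here and, for fused gated activations, halved
  // together further below, so that the strides keep using the padded sizes while the TMA
  // tensor shape uses the valid sizes.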
if (matrixType == MatrixType::MatrixC && options.mTransposeMmaOutput) { std::swap(numTokens, hiddenSize); + std::swap(numTokensValid, hiddenSizeValid); std::swap(ctaTileNumTokens, ctaTileHiddenSize); std::swap(tileNumTokens, tileHiddenSize); } @@ -125,6 +139,7 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // gated activations but not regular activations. if (options.mFusedAct && matrixType == MatrixType::MatrixC) { hiddenSize /= 2; + hiddenSizeValid /= 2; tileHiddenSize /= 2; ctaTileHiddenSize /= 2; } @@ -134,17 +149,18 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // 1, so swap the first two dimension so that the hiddenSize dimension comes first. // Activations matrix is 2D (sum(divUpMul(M[bi], tileM) for bi in B), K). - std::vector shape = {static_cast(hiddenSize), - static_cast(numTokens)}; + // Use valid dimensions for shape. + std::vector shape = {static_cast(hiddenSizeValid), + static_cast(numTokensValid)}; if (useTmaOobOpt /* also implies input/output activation */) { // If TMA OOB optimization is used: // Shape [hidden, tokens] Stride [1, hidden] becomes // Shape [hidden, tileN, TmaDimMax, TmaDimMax] Stride [1, hidden, XLargeN - hidden, hidden] - shape = {static_cast(hiddenSize), static_cast(ctaTileNumTokens), + shape = {static_cast(hiddenSizeValid), static_cast(ctaTileNumTokens), static_cast(tg::TmaDimMax), static_cast(tg::TmaDimMax)}; } else if (isWeights) { // If the matrix is a weights matrix, we use 3D logical shape (B, M, K) or (B, N, K). - shape = {static_cast(hiddenSize), static_cast(numTokens), + shape = {static_cast(hiddenSizeValid), static_cast(numTokensValid), static_cast(options.mNumBatches)}; } @@ -177,10 +193,11 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in stride[1] = numTokens; std::swap(tileShape[0], tileShape[1]); } else if (layout == gemm::MatrixLayout::BlockMajorK) { - // Set shapes based on blocking layout + // Set shapes based on blocking layout. shape = {static_cast(options.mBlockK), static_cast(numTokens), - static_cast(mK / options.mBlockK), + static_cast(sizeK / options.mBlockK), static_cast(options.mNumBatches)}; + // Strides use padded dimensions stride = {1, static_cast(options.mBlockK), static_cast(numTokens * options.mBlockK), static_cast(hiddenSize * numTokens)}; @@ -209,17 +226,6 @@ static auto makeTmaShapeStrideSfAb(int mM, int mN, int mK, MatrixType matrixType switch (layout) { case tg::SfLayout::R128c4: { - // The scaling factor tensor packs 128x4 tiles into contiguous 512B blocks. - // The 512B block maps to a 32x16B (32x128b) block in TMEM. - // See https://nvbugspro.nvidia.com/bug/4165523 - // - // Additionally, we have to meet constraints of TMA that the box dimensions are less - // than 256 and boxDim[0] is a multiple of 16B. 
- // - // The "logical" tensor is: [outer, inner / numEltsPerSf] - // The aforementioned format is: [outer / 128, inner / numEltsPerSf / 4, 512] - // The shape we use for TMA is: [outer / 128, inner / numEltsPerSf / 4, 2, 256] - auto shape = std::vector{ 256, 2, static_cast(ceilDiv(hiddenSize, numEltsPerSf * 4)), static_cast(ceilDiv(numTokens, 128))}; @@ -294,7 +300,6 @@ static auto makeTmaShapeStrideSfAb(int mM, int mN, int mK, MatrixType matrixType } return std::make_tuple(std::vector{}, std::vector{}, std::vector{}); } - template static KernelParams setKernelParams( GemmOptions_ const& options, bool const batchM, void const* ptrA, void const* ptrB, void* ptrC, @@ -390,9 +395,9 @@ static KernelParams setKernelParams( params.tileStridePerBatch = options.mM / options.mTileM; params.nm = options.mM; // Shape/stride for gmem tensor A. - auto [shapeA, strideA, tileShapeA] = - makeTmaShapeStrideAbc(options, options.mM, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixA); + auto [shapeA, strideA, tileShapeA] = makeTmaShapeStrideAbc( + options, options.mM, options.mN, options.mK, options.mTileM, options.mTileN, options.mTileK, + MatrixType::MatrixA, options.mValidM, options.mValidN, options.mValidK); // Build tma descriptor for A. params.tmaA[0] = gemm::buildNdTmaDescriptor(options.mDtypeA, options.mMmaKind, shapeA, strideA, tileShapeA, const_cast(ptrA)); @@ -469,15 +474,17 @@ static KernelParams setKernelParams( // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. - auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( - options, options.mM, ctaOffset * options.mTileN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixC); + auto [shapeC, strideC, tileShapeC] = + makeTmaShapeStrideAbc(options, options.mM, ctaOffset * options.mTileN, options.mK, + options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixC, + options.mValidM, ctaOffset * options.mTileN, options.mValidK); // Build tma descriptor for C. params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, strideC, tileShapeC, ptrC); } else { params.ptrC = ptrC; } + } else { // B is the expert if (0 != options.mN % options.mTileN) { @@ -486,9 +493,9 @@ static KernelParams setKernelParams( params.tileStridePerBatch = options.mN / options.mTileN; params.nm = options.mN; // Shape/stride for gmem tensor B. - auto [shapeB, strideB, tileShapeB] = - makeTmaShapeStrideAbc(options, options.mM, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixB); + auto [shapeB, strideB, tileShapeB] = makeTmaShapeStrideAbc( + options, options.mM, options.mN, options.mK, options.mTileM, options.mTileN, options.mTileK, + MatrixType::MatrixB, options.mValidM, options.mValidN, options.mValidK); // Build tma descriptor for B. params.tmaB[0] = gemm::buildNdTmaDescriptor(options.mDtypeB, options.mMmaKind, shapeB, strideB, tileShapeB, const_cast(ptrB)); @@ -544,9 +551,10 @@ static KernelParams setKernelParams( // C is the output activation if (options.mUseTmaStore) { // Shape/stride for gmem tensor C. 
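  // Note: in this batch-N branch the token dimension of C is already the CTA-padded
  // ctaOffset * tileM, so the updated call below passes the same value for both the padded and
  // the valid extent (no trimming along the token dimension); only N and K use
  // mValidN/mValidK.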
- auto [shapeC, strideC, tileShapeC] = makeTmaShapeStrideAbc( - options, ctaOffset * options.mTileM, options.mN, options.mK, options.mTileM, - options.mTileN, options.mTileK, MatrixType::MatrixC); + auto [shapeC, strideC, tileShapeC] = + makeTmaShapeStrideAbc(options, ctaOffset * options.mTileM, options.mN, options.mK, + options.mTileM, options.mTileN, options.mTileK, MatrixType::MatrixC, + ctaOffset * options.mTileM, options.mValidN, options.mValidK); // Build tma descriptor for C. params.tmaC[0] = gemm::buildNdTmaDescriptor(options.mDtypeC, tg::MmaKind::Auto, shapeC, strideC, tileShapeC, ptrC); diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h index 16b4af3149..e11374739f 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h @@ -29,54 +29,6 @@ struct KernelParams { // Maximum number of CTAs in the batch-token dimension. static constexpr int MaxNumCtas = 2048; - // NOTE: TMA out-of-bounds optimization for MoE padded tokens: - // - // Originally the padded tokens is a 2D tensor [hiddenDim, ctaGridDimY * tileN] with stride [1, - // hiddenDim] and box size [tileM, tileN] at pointer p. We waste bandwidth bytes since we only - // want to load [0, batchEnd) out of the [0, tileN) box size: batchEnd is a runtime variable while - // box size needs to be fixed at compile time. - // - // To deal with this, we reshape the tensor to 3D: [hiddenDim, tileN, ctaGridDimY * tileN] with - // stride [1, hiddenDim, hiddenDim] and box size [tileM, tileN, 1]. For the original 2D - // tensor, - // - // Offset Coords [ : , ctaIdxY * tileN ], - // Box Sizes [ : , tileN ], - // Coords Range [ : , ctaIdxY * tileN : ctaIdxY * tileN + tileN], - // - // while we only want load the range [ctaIdxY * tileN, ctaIdxY * tileN + batchEnd), 1 <= batchEnd - // <= tileN - // - // For the reshaped 3D tensor, - // - // Offset Coords [ : , tileN - batchEnd , - // ctaIdxY * tileN + batchEnd ], - // Box Sizes [ : , tileN , - // 1 ], - // Coords Range [ : , tileN - batchEnd : min(tileN, 2 * tileN - batchEnd), - // ctaIdxY * tileN + batchEnd : ctaIdx * tileN + batchEnd + 1], - // - // while min(tileN, 2 * tileN - batchEnd) always evaluates to tileN. The unwanted tokens are - // essentially filtered out by utilizing the OOB feature of TMA. Since the 2nd and 3rd dimension - // has the same stride, we end up loading the following (adding the left and right end of the 2nd - // and 3rd dimension ranges): - // - // Effective 2D Coords Range - // [ : , tileN + ctaIdxY * tileN : tileN + ctaIdxY * tileN + batchEnd], - // - // This is exactly the same as the original range except for the offset tileN, thus we also need - // to offset the pointer in the opposite direction: - // - // Ptr (p) -> Ptr (p - tileN * hiddenDim) - // - // Due to the restrictions of TMA unit, the above operations requires the TMA descriptor and the - // underlying buffer be constructed differently: - // - Requires valid buffer at (p - tileN * hidden) - needs prepending `tileN` tokens. - // - TMA outermost dimension must be extended by `tileN` or loads will OOB in the rightmost side. - // The latter is because when batchEnd == tileN, the offset coords in the 3rd dimension becomes - // ctaIdxY * tileN + tileN. 
When ctaIdxY = ctaGridDimY - 1, it becomes ((ctaGridDimY - 1) * tileN - // + tileN = ctaGridDimY * tileN which is equal to the 3rd dimension size and will be filtered - // out. That's why we need to extend the tensor size by tileN. // // TMA descriptor for A. // Must be setup using gemm::buildNdTmaDescriptor with shapes and strides from diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h index 4d79f83c23..4ea0a91250 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include "Enums.h" @@ -162,9 +163,12 @@ class KernelTraits { int32_t epilogueTileN, int32_t numStages, int32_t numStagesMma, int32_t numSlicesForSplitK, int32_t numSlicesForSliceK, SplitK splitK, bool useTmaStore, bool transposeMmaOutput, AllReduceAlgo allReduceAlgo, + bool fuseUtccpWithUtcmma, bool useMaxTmemOverlap, int32_t numEpilogueWarps, bool usePersistentScheduler, bool useDeepSeekFp8, bool usePerTokenSfA, bool usePerTokenSfB, bool useTwoCtas, BiasType biasType) - : mMmaKind{mmaKind} { + : mMmaKind{mmaKind}, + mFuseUtccpWithUtcmma{fuseUtccpWithUtcmma}, + mUseMaxTmemOverlap{useMaxTmemOverlap} { // // SMEM // @@ -271,6 +275,10 @@ class KernelTraits { extraGmemCMultiplier = 0; } + if (numEpilogueWarps) { + extraGmemCMultiplier *= numEpilogueWarps / 4; + } + // Number of bytes to store the output in smem. auto const numBytesSmemStoreC = usesSmemForGmemC ? extraGmemCMultiplier * epilogueTileM * epilogueTileN * @@ -418,8 +426,11 @@ class KernelTraits { std::vector tmemChunkNames; // Matrix D { + // Two set of TMEM resources for D share epilogueTileN columns, + // | set0:epiTileN0 | set0:epiTileN1/set1:epiTileN0 | set1:epiTileN1 | + auto const numCols = mUseMaxTmemOverlap ? 2 * tileN - epilogueTileN : tileN; // Number of columns for accumulators. - auto const numTmemColsD = numSlicesForSliceK * tileN * numStagesMma * + auto const numTmemColsD = numSlicesForSliceK * numCols * numStagesMma * tg::dtypeGetNumBits(dtypeAcc) / tg::dtypeGetNumBits(tg::Dtype::UInt32); // Number of columns for D alignment. @@ -466,9 +477,9 @@ class KernelTraits { auto const numTmemColsSfA = useConstSfA ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK), 4) - : (useBlockScalingA - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * numStages - : 0); + : (useBlockScalingA ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileM, mmaK)) * + (mFuseUtccpWithUtcmma ? 1 : numStages) + : 0); // Number of columns for Sf alignment. auto const numColsAlignmentSfA = 4; // No need to reuse TMEM. @@ -491,9 +502,9 @@ class KernelTraits { auto const numTmemColsSfB = useConstSfB ? tg::roundUp((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK), 4) - : (useBlockScalingB - ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * numStages - : 0); + : (useBlockScalingB ? ((tileK / 64) * tg::getTmemColStridePerGroup(tileN, mmaK)) * + (mFuseUtccpWithUtcmma ? 1 : numStages) + : 0); // Number of columns for Sf alignment. auto const numColsAlignmentSfB = 4; // No need to reuse TMEM. @@ -515,6 +526,10 @@ class KernelTraits { public: // The MMA kind. tg::MmaKind mMmaKind; + // Whether fuse Utccp into the MMA task. + bool mFuseUtccpWithUtcmma; + // Whether use the max TMEM overlap trick. + bool mUseMaxTmemOverlap; // Helper for SMEM allocation. 
MemAllocatorHelper mSmemAllocatorHelper; // Helper for TMEM allocation. diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h index a1412444ae..c7b18af138 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h @@ -156,7 +156,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + ss << "Error: Failed to initialize the TMA descriptor. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; @@ -251,7 +251,7 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor for SF " << errorString << std::endl; + ss << "Error: Failed to initialize the TMA descriptor for SF. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h index c7f1020dea..53155c8ffb 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h @@ -38,8 +38,6 @@ constexpr unsigned long XLargeN = 1UL << 35; //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline T ceilDiv(T m, T n) { return (m + n - T(1)) / n; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h index 965bb1b7b8..56b537ff42 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h @@ -63,8 +63,6 @@ enum class SfLayout { // | 1,0 | 1,1 | 1,2 | 1,3 | 33,0 | 33,1 | 33,2 | 33,3 | ... | 97,3 | // | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | // | 31,0 | 31,1 | 31,2 | 31,3 | 63,0 | 63,1 | 63,2 | 63,3 | ... | 127,3 | - // See https://nvbugspro.nvidia.com/bug/4165523 - // // I.e., the SF buffer is a tensor [⌈m/128⌉, ⌈n/b/4⌉, 32, 4, 4] // The SF for the element (i, j) is stored at (i/128, j/b/4, i%32, (i%128)/32, (j/b)%4). 
R128c4, diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 50d3baecc7..e3a0d21884 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -113,58 +113,67 @@ namespace moe::dev { FLASHINFER_WARN("Unsupported pair"); \ } +#define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mPaddingLog2 > 0) { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, true), kernel, numBlocks, numThreads, smemSize, \ + stream); \ + } else { \ + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, false), kernel, numBlocks, numThreads, \ + smemSize, stream); \ + } + #define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/), kernel, \ - numBlocks, numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, \ - numBlocks, numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } else { \ FLASHINFER_WARN("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, extraFlag, numExperts) \ - if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, float, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && 
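The new LAUNCH_TILEN macro above simply picks between the isPow2=true instantiation (padding expressed as a shift via mPaddingLog2) and the generic tile-size instantiation. A macro-free sketch of that dispatch; RoutingData and launchRoutingKernel here are placeholders for illustration, not the real launch path:

#include <cstdint>
#include <iostream>

// Placeholder for the routing launch data; only the two dispatch fields are shown.
struct RoutingData {
  int32_t mPaddingLog2 = -1;   // > 0 when the tile size is a power of two
  int32_t mTileTokensDim = 0;  // generic tile size in tokens otherwise
};

// Stand-in for the templated kernel launch selected by LAUNCH_TILEN.
template <bool IsPow2>
void launchRoutingKernel(RoutingData const& data) {
  if (IsPow2) {
    std::cout << "pow2 path, padding = " << (1 << data.mPaddingLog2) << "\n";
  } else {
    std::cout << "generic path, tileN = " << data.mTileTokensDim << "\n";
  }
}

void launchRouting(RoutingData const& data) {
  // Same runtime check as LAUNCH_TILEN: use the shift-based path when possible.
  if (data.mPaddingLog2 > 0) {
    launchRoutingKernel</*IsPow2=*/true>(data);
  } else {
    launchRoutingKernel</*IsPow2=*/false>(data);
  }
}

int main() {
  RoutingData pow2;
  pow2.mPaddingLog2 = 3;  // tile of 8 tokens, expressed as a shift
  pow2.mTileTokensDim = 8;
  launchRouting(pow2);

  RoutingData generic;
  generic.mTileTokensDim = 24;  // non-power-of-two tile
  launchRouting(generic);
  return 0;
}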
data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ +#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, extraFlag, numExperts) \ + if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, float, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeExpW"); \ } #define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ @@ -182,17 +191,17 @@ namespace moe::dev { #define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ stream, extraFlag1, numExperts) \ if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \ - 
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \ + numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \ + numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ - kernel, numBlocks, numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \ - kernel, numBlocks, numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } else { \ FLASHINFER_WARN("Unsupported dtypeExpW"); \ } diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh index dd7d5c474d..d110037269 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh @@ -67,6 +67,24 @@ __host__ __device__ constexpr T divUpMulLog2(T a, T bLog2) { return mulLog2(divUpLog2(a, bLog2), bLog2); } +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__host__ __device__ constexpr T mulTileN(T a, T tileN) { + return a * tileN; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__host__ __device__ constexpr T divUpTileN(T a, T tileN) { + return (a + tileN - 1) / tileN; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__host__ __device__ constexpr T divUpMulTileN(T a, T tileN) { + return divUpTileN(a, tileN) * tileN; +} + //////////////////////////////////////////////////////////////////////////////////////////////////// __host__ __device__ constexpr int32_t getBits(int32_t value, int idx) { @@ -299,7 +317,14 @@ __device__ void routingPermutation(KernelParams params, // Compute the runtime config for projections // Whether or not an expert is local is taken into account when smemExpertCount is computed // so we do not need to take it into account here. 
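The mulTileN / divUpTileN / divUpMulTileN helpers above generalize the existing log2-based padding math to arbitrary tile sizes, and the two paths agree whenever the tile really is a power of two. A quick host-side sketch of that equivalence with arbitrary values:

#include <cassert>
#include <cstdint>
#include <iostream>

// Log2-based helpers, mirroring the existing mulLog2 / divUpLog2 path.
constexpr int32_t mulLog2(int32_t a, int32_t bLog2) { return a << bLog2; }
constexpr int32_t divUpLog2(int32_t a, int32_t bLog2) {
  return (a + (1 << bLog2) - 1) >> bLog2;
}

// Generic tile-size helpers, mirroring the new mulTileN / divUpTileN path.
constexpr int32_t mulTileN(int32_t a, int32_t tileN) { return a * tileN; }
constexpr int32_t divUpTileN(int32_t a, int32_t tileN) {
  return (a + tileN - 1) / tileN;
}

int main() {
  int32_t const count = 37;       // tokens routed to one expert (arbitrary)
  int32_t const paddingLog2 = 3;  // tile of 8 tokens

  // When the tile is a power of two, both paths give the same CTA count and padding.
  assert(divUpLog2(count, paddingLog2) == divUpTileN(count, 1 << paddingLog2));
  assert(mulLog2(divUpLog2(count, paddingLog2), paddingLog2) ==
         mulTileN(divUpTileN(count, 1 << paddingLog2), 1 << paddingLog2));

  // The generic path also handles non-power-of-two tiles, e.g. 24 tokens per CTA.
  std::cout << "CTAs for tile 24: " << divUpTileN(count, 24) << "\n";  // 2
  return 0;
}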
- const int32_t numCta = divUpLog2(count, params.mPaddingLog2); + + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(count, params.mPaddingLog2); + } else { + numCta = divUpTileN(count, params.mTileTokensDim); + } + int32_t ctaOffset; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); @@ -310,21 +335,37 @@ __device__ void routingPermutation(KernelParams params, const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = - min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffset, params.mPaddingLog2) + count); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } // get the padded offset associated with this expert - const int32_t offset = mulLog2(ctaOffset, params.mPaddingLog2); - + int32_t offset; + if constexpr (KernelParams::isPow2) { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } // write expert offsets to shared smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; } // write out padded count if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) { - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } @@ -513,14 +554,25 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) // Compute the runtime config for projections // Whether or not an expert is local is taken into account when the histogram is computed // so we do not need to take it into account here. 
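Both branches above keep the same bookkeeping: each expert contributes numCta tiles, an exclusive scan of numCta gives ctaOffset, and mnLimit clips an expert's last tile to its real token count. A host-side sketch of that bookkeeping for a few experts; the counts are made up, and the device code performs the scan with a CUB block scan rather than a serial loop:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Tokens routed to each local expert, and the tile size in tokens (illustrative).
  std::vector<int32_t> const counts = {5, 0, 17, 8};
  int32_t const tileTokensDim = 8;

  int32_t ctaOffset = 0;  // exclusive prefix sum of numCta across experts
  for (size_t expert = 0; expert < counts.size(); ++expert) {
    int32_t const count = counts[expert];
    int32_t const numCta = (count + tileTokensDim - 1) / tileTokensDim;
    // Padded offset of this expert in the permuted index space.
    int32_t const offset = ctaOffset * tileTokensDim;
    for (int32_t cta = 0; cta < numCta; ++cta) {
      // The expert's last tile is clipped to its real token count.
      int32_t const mnLimit = std::min((ctaOffset + cta + 1) * tileTokensDim,
                                       ctaOffset * tileTokensDim + count);
      std::cout << "expert " << expert << " cta " << cta << ": offset " << offset
                << ", mnLimit " << mnLimit << "\n";
    }
    ctaOffset += numCta;
  }
  // Total padded size, as written to mPtrPermutedIdxSize on the generic tile path.
  std::cout << "permutedIdxSize = " << ctaOffset * tileTokensDim << "\n";
  return 0;
}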
- const int32_t numCta = divUpLog2(count, params.mPaddingLog2); + // const int32_t numCta = divUpLog2(count, params.mPaddingLog2); + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(count, params.mPaddingLog2); + } else { + numCta = divUpTileN(count, params.mTileTokensDim); + } int32_t ctaOffset; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); if (threadIdx.x < params.mNumExperts) { // Get the padded offset associated with this expert - const int32_t offset = mulLog2(ctaOffset, params.mPaddingLog2); + int32_t offset; + if constexpr (KernelParams::isPow2) { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } // Write expert offsets to shared smemExpertOffset[threadIdx.x] = offset; @@ -532,7 +584,12 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) // The first block writes out padded count if (blockIdx.x == 0 && warpIdx == KernelParams::MaxNumExperts / WarpSize - 1 && cute::elect_one_sync()) { - const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } @@ -543,9 +600,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = - min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2), - mulLog2(ctaOffset, params.mPaddingLog2) + count); + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } } diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index e424d91db0..cae6729368 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -50,7 +50,7 @@ struct DataBase { // dim: [mNumTokens * mTopK] int32_t* mPtrExpandedIdxToPermutedIdx{nullptr}; // optional: if `nullptr`, it is not filled - // dim: [mNumTokens * mTopK + (mNumExperts << mPaddingLog2) - mNumExperts] + // dim: [mTileTokensDim * mTopK + (mNumExperts × mTileTokensDim) - mNumExperts] // Note: this array (mPtrPermutedIdxToTokenIdx) is uninitialized // Any out-of-bounds values are undefined. 
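Regarding the buffer-size comment above: rounding each expert's token count up to a multiple of the tile size can add at most mTileTokensDim - 1 slots per expert on top of mNumTokens * mTopK. The bound below mirrors the original power-of-two formula with 1 << mPaddingLog2 replaced by mTileTokensDim; treat it as an illustration of the worst case, not the authoritative sizing rule:

#include <cstdint>
#include <iostream>

// Worst-case number of entries in the permuted index space when every expert's
// token count is rounded up to a multiple of tileTokensDim (sketch only).
int64_t maxPermutedIdxSize(int64_t numTokens, int64_t topK, int64_t numExperts,
                           int64_t tileTokensDim) {
  // Each expert adds at most (tileTokensDim - 1) padding slots.
  return numTokens * topK + numExperts * (tileTokensDim - 1);
}

int main() {
  // Hypothetical routing problem: 1024 tokens, top-8, 256 experts, tiles of 8 tokens.
  std::cout << maxPermutedIdxSize(1024, 8, 256, 8) << "\n";  // 8192 + 1792 = 9984
  return 0;
}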
int32_t* mPtrPermutedIdxToTokenIdx{nullptr}; @@ -93,6 +93,7 @@ struct DataBase { int32_t mNumExperts; int32_t mTopK; int32_t mPaddingLog2; + int32_t mTileTokensDim; /// For expert parallelization int32_t mLocalExpertsStartIdx; @@ -100,11 +101,12 @@ struct DataBase { int32_t mNumLocalExperts; }; -template +template struct KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; static constexpr int MaxNumExperts = MaxNumExperts_; + static constexpr bool isPow2 = isPow2_; static constexpr bool UsePdl = UsePdl_; // Public pointer members @@ -123,7 +125,8 @@ struct KernelParamsBase { int32_t mNumTokens = 0; int32_t mNumExperts = 0; - int32_t mPaddingLog2 = 0; + int32_t mPaddingLog2 = -1; + int32_t mTileTokensDim = 0; int32_t mLocalExpertsStartIdx = 0; int32_t mLocalExpertsStrideLog2 = 0; int32_t mNumLocalExperts = 0; @@ -146,6 +149,7 @@ struct KernelParamsBase { mNumExperts = data.mNumExperts; mPaddingLog2 = data.mPaddingLog2; + mTileTokensDim = data.mTileTokensDim; mLocalExpertsStartIdx = data.mLocalExpertsStartIdx; mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2; mNumLocalExperts = data.mNumLocalExperts; @@ -173,8 +177,8 @@ struct Data : public DataBase { }; template -struct KernelParams : public KernelParamsBase { + bool isPow2_, bool UsePdl_> +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using BiasT = BiasT_; using OutputT = OutputT_; @@ -229,8 +233,8 @@ struct Data : public DataBase { tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; @@ -268,8 +272,8 @@ struct Data : public DataBase { }; template -struct KernelParams : public KernelParamsBase { + bool isPow2_, bool UsePdl_> +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index df19e00310..65f497ad90 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -2084,80 +2084,57 @@ def run_moe_test( ) -# Test: DeepSeekV3 routing +# Test: Renormalize routing @pytest.mark.parametrize("num_tokens", [1, 8, 1024]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( "moe_impl", [ + pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), - pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), ], ) @pytest.mark.parametrize( "routing_config", [ - pytest.param( - { - "num_experts": 384, - "top_k": 8, - "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], - "compatible_intermediate_size": [512, 1024, 2048], - }, - id="kimi_k2", - ), pytest.param( { "num_experts": 256, "top_k": 8, "padding": 8, - "n_groups": 8, - "top_k_groups": 4, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], - "compatible_intermediate_size": [512, 1024, 2048], + "n_groups": None, + "top_k_groups": None, + 
"routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], + "compatible_intermediate_size": [384, 768, 1024, 2048], }, - id="DSv3", + id="Renorm", ), pytest.param( { - "num_experts": 72, - "top_k": 6, + "num_experts": 512, + "top_k": 10, "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], - "compatible_intermediate_size": [384, 768], + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], + "compatible_intermediate_size": [512], }, - id="DSLite", + id="Qwen3_next", ), ], ) @pytest.mark.parametrize( "weight_processing", [ - pytest.param( - { - "use_shuffled_weight": False, - "layout": WeightLayout.MajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], - }, - id="NoShuffle_MajorK", - ), pytest.param( { "use_shuffled_weight": True, @@ -2166,14 +2143,6 @@ def run_moe_test( }, id="Shuffled_MajorK", ), - pytest.param( - { - "use_shuffled_weight": True, - "layout": WeightLayout.BlockMajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], - }, - id="Shuffled_BlockMajorK", - ), ], ) @pytest.mark.parametrize( @@ -2183,7 +2152,7 @@ def run_moe_test( pytest.param(GatedActType.GeGlu, id="GeGlu"), ], ) -def test_deepseekv3_routing( +def test_renormalize_routing( num_tokens, hidden_size, intermediate_size, @@ -2193,7 +2162,7 @@ def test_deepseekv3_routing( gated_act_type, cache_permute_indices, ): - """Test DeepSeekV3 routing configurations.""" + """Test Renormalize routing configurations.""" run_moe_test( num_tokens, hidden_size, @@ -2206,58 +2175,80 @@ def test_deepseekv3_routing( ) -# Test: Renormalize routing +# Test: DeepSeekV3 routing @pytest.mark.parametrize("num_tokens", [1, 8, 1024]) @pytest.mark.parametrize("hidden_size", [1024]) @pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( "moe_impl", [ + pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), - pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), ], ) @pytest.mark.parametrize( "routing_config", [ + pytest.param( + { + "num_experts": 384, + "top_k": 8, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [512, 1024, 2048], + }, + id="kimi_k2", + ), pytest.param( { "num_experts": 256, "top_k": 8, "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], - "compatible_intermediate_size": [384, 768, 1024, 2048], + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [512, 1024, 2048], }, - id="Renorm", - 
marks=pytest.mark.skip(reason="Skip temporary"), + id="DSv3", ), pytest.param( { - "num_experts": 512, - "top_k": 10, + "num_experts": 72, + "top_k": 6, "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], - "compatible_intermediate_size": [512], + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [384, 768], }, - id="Qwen3_next", + id="DSLite", ), ], ) @pytest.mark.parametrize( "weight_processing", [ + pytest.param( + { + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="NoShuffle_MajorK", + ), pytest.param( { "use_shuffled_weight": True, @@ -2266,6 +2257,14 @@ def test_deepseekv3_routing( }, id="Shuffled_MajorK", ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.BlockMajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="Shuffled_BlockMajorK", + ), ], ) @pytest.mark.parametrize( @@ -2275,7 +2274,7 @@ def test_deepseekv3_routing( pytest.param(GatedActType.GeGlu, id="GeGlu"), ], ) -def test_renormalize_routing( +def test_deepseekv3_routing( num_tokens, hidden_size, intermediate_size, @@ -2285,7 +2284,7 @@ def test_renormalize_routing( gated_act_type, cache_permute_indices, ): - """Test Renormalize routing configurations.""" + """Test DeepSeekV3 routing configurations.""" run_moe_test( num_tokens, hidden_size,