diff --git a/benchmarks/bench_moe_deepseek.py b/benchmarks/bench_moe_deepseek.py index a2f032c5cc..9bf5b5493f 100644 --- a/benchmarks/bench_moe_deepseek.py +++ b/benchmarks/bench_moe_deepseek.py @@ -34,11 +34,24 @@ """ import argparse +from contextlib import contextmanager from dataclasses import dataclass import numpy as np import torch +@contextmanager +def cuda_profiler_range(name): + """Context manager for CUDA profiler + NVTX range.""" + torch.cuda.cudart().cudaProfilerStart() + torch.cuda.nvtx.range_push(name) + try: + yield + finally: + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() + + @dataclass class DeepSeekConfig: hidden_size: int = 7168 @@ -78,28 +91,44 @@ def is_sm100_family(): return props.major == 10 -def calc_tflops(n, ms, num_local_experts=None): - """Calculate TFLOPS for MoE computation. +NVFP4_BYTES = 9 / 16 # 0.5 bytes value + 1/16 byte block scale +BF16_BYTES = 2 + + +def calc_tflops(local_tokens, ms): + """Calculate TFLOPS using actual routed token count. - With EP, only tokens routed to local experts are computed. - Assumes uniform routing distribution across experts. + FC1: [M, H] x [H, 2I] + FC2: [M, I] x [I, H] + FLOPs = 2 * local_tokens * (H*2I + I*H) = 6 * local_tokens * H * I """ - if num_local_experts is None: - num_local_experts = CFG.num_experts + H = CFG.hidden_size + I = CFG.intermediate_size + flops = local_tokens * (2 * H * 2 * I + 2 * I * H) + return flops / (ms * 1e-3) / 1e12 - # Fraction of work done locally (assuming uniform distribution) - local_fraction = num_local_experts / CFG.num_experts - flops = ( - n - * CFG.top_k - * local_fraction # Only local expert pairs are computed - * ( - 2 * CFG.hidden_size * 2 * CFG.intermediate_size - + 2 * CFG.intermediate_size * CFG.hidden_size - ) +def calc_bw(local_tokens, active_experts, ms): + """Calculate achieved bandwidth in TB/s for MoE FC1 + FC2. + + Weights are read once per active expert. 
+ FC1: nvfp4 input [M, H] x nvfp4 weight [H, 2I] -> nvfp4 output [M, 2I] + FC2: nvfp4 input [M, I] x nvfp4 weight [I, H] -> bf16 output [M, H] + """ + H = CFG.hidden_size + I = CFG.intermediate_size + + weight_bytes = active_experts * (H * 2 * I + I * H) * NVFP4_BYTES + + # Note: each implementation differs in how data is read/written within the module. + # So here, we only account for the MoE module's read/write bytes. + act_bytes = ( + local_tokens * H * NVFP4_BYTES # FC1 input read + + local_tokens * H * BF16_BYTES # FC2 output write ) - return flops / (ms * 1e-3) / 1e12 + + total_bytes = weight_bytes + act_bytes + return total_bytes / (ms * 1e-3) / 1e12 def interleave(x, gs=64): @@ -342,15 +371,16 @@ def run(x, x_sf, router_logits, routing_bias, topk_values, topk_indices): "topk_indices": ti, } - times = bench_gpu_time( - run, - dry_run_iters=warmup, - repeat_iters=iters, - cold_l2_cache=True, - enable_cupti=use_cupti, - use_cuda_graph=use_cuda_graph, - input_kwargs=input_kwargs, - ) + with cuda_profiler_range("bench_cute_dsl"): + times = bench_gpu_time( + run, + dry_run_iters=warmup, + repeat_iters=iters, + cold_l2_cache=True, + enable_cupti=use_cupti, + use_cuda_graph=use_cuda_graph, + input_kwargs=input_kwargs, + ) return np.median(times) @@ -449,15 +479,16 @@ def run(hidden, sf, router_logits, routing_bias, topk_values, topk_indices): "topk_indices": ti, } - times = bench_gpu_time( - run, - dry_run_iters=warmup, - repeat_iters=iters, - cold_l2_cache=True, - enable_cupti=use_cupti, - use_cuda_graph=use_cuda_graph, - input_kwargs=input_kwargs, - ) + with cuda_profiler_range("bench_cutlass"): + times = bench_gpu_time( + run, + dry_run_iters=warmup, + repeat_iters=iters, + cold_l2_cache=True, + enable_cupti=use_cupti, + use_cuda_graph=use_cuda_graph, + input_kwargs=input_kwargs, + ) return np.median(times) @@ -575,15 +606,16 @@ def run(routing_logits, routing_bias, hidden_states, hidden_states_scale): "hidden_states_scale": hsc, } - times = bench_gpu_time( - 
run, - dry_run_iters=warmup, - repeat_iters=iters, - cold_l2_cache=True, - enable_cupti=use_cupti, - use_cuda_graph=use_cuda_graph, - input_kwargs=input_kwargs, - ) + with cuda_profiler_range("bench_trtllm"): + times = bench_gpu_time( + run, + dry_run_iters=warmup, + repeat_iters=iters, + cold_l2_cache=True, + enable_cupti=use_cupti, + use_cuda_graph=use_cuda_graph, + input_kwargs=input_kwargs, + ) return np.median(times) @@ -817,6 +849,7 @@ class BenchResult: tokens: int latency_ms: float tflops: float + bw_tb_s: float def run_benchmark( @@ -935,7 +968,9 @@ def _benchmark_single( ), } - # Build results + # Build results using actual routed token counts + local_tokens = histogram_record["local_tokens"] + active_experts = histogram_record["active_local_experts"] results = [] for backend, latency in lat.items(): results.append( @@ -943,7 +978,8 @@ def _benchmark_single( backend=backend, tokens=n, latency_ms=latency, - tflops=calc_tflops(n, latency, num_local), + tflops=calc_tflops(local_tokens, latency), + bw_tb_s=calc_bw(local_tokens, active_experts, latency), ) ) return results, histogram_record @@ -957,9 +993,9 @@ def _print_header( routing_bias_scale, ): """Print benchmark header.""" - print("\n" + "=" * 120) + print("\n" + "=" * 142) print(f"DeepSeek-V3 MoE Benchmark: CuteDSL vs CUTLASS vs TRTLLM (EP={ep_config})") - print("=" * 120) + print("=" * 142) print( f"Model: hidden={CFG.hidden_size}, intermediate={CFG.intermediate_size}, " f"experts={CFG.num_experts}, top_k={CFG.top_k}" @@ -974,28 +1010,28 @@ def _print_header( f"Routing bias scale: {routing_bias_scale} " f"(larger values tend to create expert imbalance)" ) - print("-" * 120) + print("-" * 142) print( f"{'Tokens':>6} | " - f"{'CuteDSL':^15} | " - f"{'CUTLASS':^15} | " - f"{'TRTLLM':^15} | " + f"{'CuteDSL':^22} | " + f"{'CUTLASS':^22} | " + f"{'TRTLLM':^22} | " f"{'Speedup (CuteDSL/X)':^18} | " f"{'Winner':^8} | " f"{'Active':^7} | " - f"{'Stats':^14}" + f"{'Tokens/slot':^14}" ) print( f"{'':>6} | " - 
f"{'ms':>7} {'TFLOPS':>7} | " - f"{'ms':>7} {'TFLOPS':>7} | " - f"{'ms':>7} {'TFLOPS':>7} | " + f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | " + f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | " + f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | " f"{'CUTLASS':>9} {'TRTLLM':>9} | " f"{'':^8} | " f"{'experts':^7} | " - f"{'min/max/median':^14}" + f"{'min/median/max':^14}" ) - print("-" * 120) + print("-" * 142) def _print_row(results, histogram_record): @@ -1014,14 +1050,14 @@ def _print_row(results, histogram_record): active_experts = f"{histogram_record['active_local_experts']:>3}" stats = ( f"{histogram_record['min_count']:>3}/" - f"{histogram_record['max_count']:>3}/" - f"{histogram_record['median_count']:>7.2f}" + f"{histogram_record['median_count']:>5.1f}/" + f"{histogram_record['max_count']:>4}" ) print( f"{cute.tokens:>6} | " - f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} | " - f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} | " - f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} | " + f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} {cute.bw_tb_s:>6.1f} | " + f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} {cutlass.bw_tb_s:>6.1f} | " + f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} {trtllm.bw_tb_s:>6.1f} | " f"{speedup_cutlass:>8.2f}x {speedup_trtllm:>8.2f}x | " f"{winner:^8} | " f"{active_experts:>7} | " @@ -1031,8 +1067,12 @@ def _print_row(results, histogram_record): def _print_footer(ep_config, num_local): """Print benchmark footer.""" - print("-" * 120) + print("-" * 142) print("Speedup > 1.0 means CuteDSL is faster than that backend") + print( + f"TFLOPS/BW use actual routed token counts. " + f"BW assumes nvfp4 = {NVFP4_BYTES:.4f} B/elem, bf16 = {BF16_BYTES} B/elem." 
+ ) def _collect_expert_histogram(inputs, num_local, local_offset): @@ -1059,6 +1099,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset): ) local_hist = expert_hist[local_offset : local_offset + num_local] local_hist_f32 = local_hist.to(torch.float32) + local_tokens = int(local_hist.sum().item()) active_local_experts = int((local_hist > 0).sum().item()) if local_hist.numel() > 0: min_count = int(local_hist.min().item()) @@ -1070,6 +1111,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset): median_count = 0.0 return { + "local_tokens": local_tokens, "active_local_experts": active_local_experts, "min_count": min_count, "max_count": max_count, diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f4cb825d36..2797ef3b35 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -36,6 +36,24 @@ using namespace batchedGemm::trtllm::gen; static BatchedGemmInterface::ModuleCache globalTrtllmGenBatchedGemmModuleCache; +static inline bool skipQuirks(BatchedGemmConfig const& config) { + // Skip kernels that are known to hang/crash. Keep a record here for future reference. 
+ auto const& options = config.mOptions; + // FC1 128x128 batchM mNumWarpsLoadSfA=4 (=2 is ok) + // bmm_E2m1_E2m1E2m1_Fp32_bA16_bB16_bC16_t128x128x256_s6_et128x64_m256x128x64_c2x1x1_32dp32b_rN_TN_schPd2x1x2x3_biasFp32N_bM_tma_ldgstsSf_rgTma_clmp_swiGlu_lsfaW4_dynB_sm100f + bool const isKnownHangingSm100fSwigluLsfaW4Family = + config.mSm == SmVersion::Sm100f && options.mDtypeA == tg::Dtype::E2m1 && + options.mDtypeB == tg::Dtype::E2m1 && options.mTileM == 128 && options.mTileN == 128 && + options.mTileK == 256 && options.mNumStages == 6 && options.mClusterDimX == 2 && + options.mClusterDimY == 1 && options.mClusterDimZ == 1 && + !doesRouteImplUseNoRoute(options.mRouteImpl) && !options.mTransposeMmaOutput && + options.mTileScheduler == TileScheduler::Persistent && options.mFusedAct && + options.mActType == batchedGemm::gemmGatedAct::ActType::SwiGlu && + options.mNumWarpsLoadSfA == 4; + + return isKnownHangingSm100fSwigluLsfaW4Family; +} + std::vector prioritizePredefinedConfigs( int m, int n, int k, std::vector const& sortedIndices, batchedGemm::batchedGemm::BatchedGemmConfig const* configs) { @@ -83,6 +101,29 @@ std::vector prioritizePredefinedConfigs( return prioritizedIndices; } +static inline void setProblemDimensions(BatchedGemmData& gemmData, bool transposeMmaOutput, + int32_t m, int32_t n, int32_t k, + std::vector const& batchedTokens, + int32_t numTokens, int32_t numBatches, + int32_t maxNumCtasInBatchDim) { + gemmData.mProblemDimensions.mNumBatches = numBatches; + gemmData.mProblemDimensions.mNumTokens = numTokens; + gemmData.mProblemDimensions.mBatchM = !transposeMmaOutput; + gemmData.mProblemDimensions.mBatchedM = + transposeMmaOutput ? std::vector{} : batchedTokens; + gemmData.mProblemDimensions.mBatchedN = + transposeMmaOutput ? batchedTokens : std::vector{}; + gemmData.mProblemDimensions.mM = transposeMmaOutput ? n : m; + gemmData.mProblemDimensions.mN = transposeMmaOutput ? 
m : n; + gemmData.mProblemDimensions.mK = k; + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; + gemmData.mProblemDimensions.mRank = 0; + gemmData.mProblemDimensions.mWorldSize = 1; + gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; +} + TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( TrtllmGenBatchedGemmRunnerOptions const& options_) : mOptions(options_) { @@ -93,16 +134,28 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( mPassingConfigIndices.clear(); for (size_t i = 0; i < bmm.getNumBatchedGemmConfigs(); ++i) { - auto const options = configs[i].mOptions; - auto const tileSize = mOptions.transposeMmaOutput ? options.mTileN : options.mTileM; - // When we include low-latency kernels we can set transposeMmaOutput via constructor - if (options.mDtypeA == mOptions.dtypeA && options.mDtypeB == mOptions.dtypeB && - options.mDtypeC == mOptions.dtypeC && options.mUseDeepSeekFp8 == mOptions.deepSeekFp8 && - options.mTransposeMmaOutput == mOptions.transposeMmaOutput && + // The kernel config. + auto const& config = configs[i]; + auto const& options = config.mOptions; + // The tile size in CGA granularity. + auto const tileSize = options.mTransposeMmaOutput ? options.mTileN * options.mClusterDimY + : options.mTileM * options.mClusterDimX; + // Check if kernel dtype matches runner config. + bool const dtypeMatch = + options.mTransposeMmaOutput + ? (options.mDtypeA == mOptions.dtypeB && options.mDtypeB == mOptions.dtypeA) + : (options.mDtypeA == mOptions.dtypeA && options.mDtypeB == mOptions.dtypeB); + // Check if kernel weight layout matches runner config. + bool const layoutAndShuffleMatch = + options.mTransposeMmaOutput ? 
(options.mUseShuffledMatrix == mOptions.useShuffledMatrix && + options.mLayoutA == mOptions.weightLayout) + : (options.mUseShuffledMatrix == mOptions.useShuffledMatrix && + options.mLayoutB == mOptions.weightLayout); + if (dtypeMatch && options.mDtypeC == mOptions.dtypeC && + options.mUseDeepSeekFp8 == mOptions.deepSeekFp8 && (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct && options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch && - tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix && - options.mLayoutA == mOptions.weightLayout) { + tileSize == mOptions.tileSize && layoutAndShuffleMatch) { if (options.mFusedAct) { if (options.mActType != static_cast(mOptions.actType)) { continue; @@ -111,8 +164,10 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) { continue; } - - if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) { + if (skipQuirks(config)) { + continue; + } + if (options.mEpilogueTileM == mOptions.epilogueTileM) { mPassingConfigIndices.push_back(i); } } @@ -126,7 +181,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( << ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8 << ", mActType: " << (int64_t)mOptions.actType << ", mEltwiseActType: " << (int64_t)mOptions.eltwiseActType - << ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput + << ", mTransposeMmaOutput: auto-tuned" << ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct << ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize; FLASHINFER_CHECK(!mPassingConfigIndices.empty(), error_msg.str()); @@ -135,31 +190,14 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes( int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, int32_t numTokens, int32_t numBatches, 
int32_t maxNumCtasInBatchDim, int32_t configIndex) const { - BatchedGemmData gemmData{}; - gemmData.mProblemDimensions.mNumBatches = numBatches; - gemmData.mProblemDimensions.mNumTokens = numTokens; - gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput; - gemmData.mProblemDimensions.mBatchedM = - mOptions.transposeMmaOutput ? std::vector{} : batchedTokens; - gemmData.mProblemDimensions.mBatchedN = - mOptions.transposeMmaOutput ? batchedTokens : std::vector{}; - gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m; - gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n; - gemmData.mProblemDimensions.mK = k; - gemmData.mProblemDimensions.mRank = 0; - gemmData.mProblemDimensions.mWorldSize = 1; - gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; - - gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; - gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; - gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; - auto bmm = BatchedGemmInterface(); - auto const configs = bmm.getBatchedGemmConfigs(); - auto const& config = configs[configIndex]; + BatchedGemmData gemmData{}; + setProblemDimensions(gemmData, config.mOptions.mTransposeMmaOutput, m, n, k, batchedTokens, + numTokens, numBatches, maxNumCtasInBatchDim); + return bmm.getWorkspaceSizeInBytes(config, gemmData); } @@ -179,7 +217,7 @@ void TrtllmGenBatchedGemmRunner::run( auto const configs = bmm.getBatchedGemmConfigs(); auto const& config = configs[configIndex]; - // printf("running config %d: %s\n", configIndex, config.mFunctionName); + bool const transposeMmaOutput = config.mOptions.mTransposeMmaOutput; FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0"); if (!mOptions.staticBatch) { @@ -204,35 +242,20 @@ void TrtllmGenBatchedGemmRunner::run( } // Dims - gemmData.mProblemDimensions.mNumBatches = numBatches; - gemmData.mProblemDimensions.mNumTokens = numTokens; - 
gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput; - gemmData.mProblemDimensions.mBatchedM = - mOptions.transposeMmaOutput ? std::vector{} : batchedTokens; - gemmData.mProblemDimensions.mBatchedN = - mOptions.transposeMmaOutput ? batchedTokens : std::vector{}; - gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m; - gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n; - gemmData.mProblemDimensions.mK = k; - gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; - gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; - gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; - gemmData.mProblemDimensions.mRank = 0; - gemmData.mProblemDimensions.mWorldSize = 1; + setProblemDimensions(gemmData, transposeMmaOutput, m, n, k, batchedTokens, numTokens, numBatches, + maxNumCtasInBatchDim); // Inputs - gemmData.mInputBuffers.mPtrA = mOptions.transposeMmaOutput ? b : a; - gemmData.mInputBuffers.mPtrSfA = mOptions.transposeMmaOutput ? sfB : sfA; - gemmData.mInputBuffers.mPtrB = mOptions.transposeMmaOutput ? a : b; - gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB; + gemmData.mInputBuffers.mPtrA = transposeMmaOutput ? b : a; + gemmData.mInputBuffers.mPtrSfA = transposeMmaOutput ? sfB : sfA; + gemmData.mInputBuffers.mPtrB = transposeMmaOutput ? a : b; + gemmData.mInputBuffers.mPtrSfB = transposeMmaOutput ? sfA : sfB; gemmData.mInputBuffers.mPtrScaleC = scaleC; gemmData.mInputBuffers.mPtrScaleGate = scaleGateC; // For simplicity pass set scaleAct to scaleGateC gemmData.mInputBuffers.mPtrScaleAct = scaleGateC; - gemmData.mInputBuffers.mPtrPerTokenSfA = - mOptions.transposeMmaOutput ? perTokensSfB : perTokensSfA; - gemmData.mInputBuffers.mPtrPerTokenSfB = - mOptions.transposeMmaOutput ? perTokensSfA : perTokensSfB; + gemmData.mInputBuffers.mPtrPerTokenSfA = transposeMmaOutput ? 
 perTokensSfB : perTokensSfA; + gemmData.mInputBuffers.mPtrPerTokenSfB = transposeMmaOutput ? perTokensSfA : perTokensSfB; gemmData.mInputBuffers.mPtrBias = ptrBias; gemmData.mInputBuffers.mPtrGatedActAlpha = ptrAlpha; gemmData.mInputBuffers.mPtrGatedActBeta = ptrBeta; @@ -240,8 +263,6 @@ void TrtllmGenBatchedGemmRunner::run( gemmData.mInputBuffers.mPtrRouteMap = routeMap; - gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; - // Pointer to total number of padded tokens gemmData.mInputBuffers.mPtrTotalNumPaddedTokens = totalNumPaddedTokens; gemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; @@ -255,6 +276,15 @@ int32_t multiProcessorCount; cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device); + if (getBoolEnv("TRTLLM_BATCHED_GEMM_PRINT_NAME")) { + FLASHINFER_LOG("NumBatches", numBatches, ", NumTokens", numTokens, + ", MaxNumCtasInBatchDim", maxNumCtasInBatchDim, ", ShapeMNK", + gemmData.mProblemDimensions.mM, gemmData.mProblemDimensions.mN, + gemmData.mProblemDimensions.mK, ", ValidShapeMNK", + gemmData.mProblemDimensions.mValidM, gemmData.mProblemDimensions.mValidN, + gemmData.mProblemDimensions.mValidK, ", Kernel", config.mFunctionName); + } + // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere bmm.runInitBeforeWorldSync(config, gemmData, static_cast(stream)); @@ -266,7 +296,9 @@ "Error occurred when running GEMM!" 
" (numBatches: ", numBatches, ", GemmMNK: ", m, " ", n, " ", k, ", Kernel: ", config.mFunctionName, - ")"); + ", transposeMmaOutput: ", transposeMmaOutput, ", configIndex: ", configIndex, + ", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, + ", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, ")"); } void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, @@ -328,30 +360,16 @@ std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); - BatchedGemmData gemmData{}; - // Dims - gemmData.mProblemDimensions.mNumBatches = numBatches; - gemmData.mProblemDimensions.mNumTokens = numTokens; - gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput; - gemmData.mProblemDimensions.mBatchedM = - mOptions.transposeMmaOutput ? std::vector{} : batchedTokens; - gemmData.mProblemDimensions.mBatchedN = - mOptions.transposeMmaOutput ? batchedTokens : std::vector{}; - gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m; - gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? 
m : n; - gemmData.mProblemDimensions.mK = k; - gemmData.mProblemDimensions.mRank = 0; - gemmData.mProblemDimensions.mWorldSize = 1; - gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; - - gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; - gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; - gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; - - auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) { + auto cmpFunc = [&configs, &bmm, &multiProcessorCount, &m, &n, &k, &batchedTokens, &numTokens, + &numBatches, &maxNumCtasInBatchDim](int64_t idx0, int64_t idx1) { auto const& optionsA = configs[idx0].mOptions; auto const& optionsB = configs[idx1].mOptions; - int32_t sizeK = gemmData.mProblemDimensions.mK; + int32_t sizeK = k; + + // Keep comparator stable across mixed transpose modes. + if (optionsA.mTransposeMmaOutput != optionsB.mTransposeMmaOutput) { + return optionsA.mTransposeMmaOutput; + } // Tier 0: K < tileK, prefer higher efficiency. if (optionsA.mTileK != optionsB.mTileK) { @@ -385,6 +403,9 @@ std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( // Tier 2+: When previous comparators are the same, and when the number of estimated CTAs is on // the larger side, prefer persistent tile scheduler. if (optionsA.mTileScheduler != optionsB.mTileScheduler) { + BatchedGemmData gemmData{}; + setProblemDimensions(gemmData, optionsA.mTransposeMmaOutput, m, n, k, batchedTokens, + numTokens, numBatches, maxNumCtasInBatchDim); auto options = bmm.getOptionsFromConfigAndData(configs[idx0], gemmData); auto numCtas = bmm.getNumCtas(options, gemmData.mProblemDimensions.mMaxNumCtasInTokenDim); if (numCtas > multiProcessorCount) { @@ -408,14 +429,34 @@ std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( // Filter out invalid configs. 
std::vector validConfigIndices; for (auto const& configIndex : prioritizedIndices) { + BatchedGemmData gemmData{}; + auto const transposeMmaOutput = configs[configIndex].mOptions.mTransposeMmaOutput; + setProblemDimensions(gemmData, transposeMmaOutput, m, n, k, batchedTokens, numTokens, + numBatches, maxNumCtasInBatchDim); auto isValidConfig = bmm.isValidConfig(configs[configIndex], gemmData); if (isValidConfig) { validConfigIndices.push_back(configIndex); } } - FLASHINFER_CHECK(!validConfigIndices.empty(), - "No valid config found for the given problem shape"); + std::ostringstream error_msg; + if (validConfigIndices.empty()) { + int64_t numTransposeConfigs = 0; + for (auto const& configIndex : prioritizedIndices) { + if (configs[configIndex].mOptions.mTransposeMmaOutput) { + ++numTransposeConfigs; + } + } + error_msg << "No valid config found for the given problem shape" + << " (m=" << m << ", n=" << n << ", k=" << k << ", numTokens=" << numTokens + << ", numBatches=" << numBatches << ", maxNumCtasInBatchDim=" << maxNumCtasInBatchDim + << ", passingConfigs=" << mPassingConfigIndices.size() + << ", prioritizedConfigs=" << prioritizedIndices.size() + << ", transposeConfigs=" << numTransposeConfigs + << ", nonTransposeConfigs=" << (prioritizedIndices.size() - numTransposeConfigs) + << ")"; + } + FLASHINFER_CHECK(!validConfigIndices.empty(), error_msg.str()); return validConfigIndices; } @@ -436,30 +477,11 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t int32_t maxNumCtasInBatchDim) const { auto const bmm = BatchedGemmInterface(); auto const configs = bmm.getBatchedGemmConfigs(); + auto const& config = configs[configIndex]; BatchedGemmData gemmData{}; - // Dims - gemmData.mProblemDimensions.mNumBatches = numBatches; - gemmData.mProblemDimensions.mNumTokens = numTokens; - gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput; - gemmData.mProblemDimensions.mBatchedM = - mOptions.transposeMmaOutput ? 
std::vector{} : batchedTokens; - gemmData.mProblemDimensions.mBatchedN = - mOptions.transposeMmaOutput ? batchedTokens : std::vector{}; - gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m; - gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n; - gemmData.mProblemDimensions.mK = k; - gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; - gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; - gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; - gemmData.mProblemDimensions.mRank = 0; - gemmData.mProblemDimensions.mWorldSize = 1; - gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; - gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; - gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; - gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; - - auto const& config = configs[configIndex]; + setProblemDimensions(gemmData, config.mOptions.mTransposeMmaOutput, m, n, k, batchedTokens, + numTokens, numBatches, maxNumCtasInBatchDim); return bmm.isValidConfig(config, gemmData); } diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 7c5826802e..ea3593016f 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -324,7 +324,7 @@ class FusedMoeLauncher { Tensor num_non_exiting_ctas; void prepare_routing_common() { - // Allocate routing phase workspace tensors + // Allocate routing phase workspace tensors. 
num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device()); int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( @@ -360,7 +360,6 @@ class FusedMoeLauncher { workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); workspace.total_max_padded_tokens = max_num_padded_tokens; - workspace.ProjUpTileN = tile_tokens_dim; workspace.routing_expert_indexes = static_cast(expert_indexes.data_ptr()); workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); workspace.expanded_idx_to_permuted_idx = @@ -390,7 +389,8 @@ class FusedMoeLauncher { TVM_FFI_ICHECK_EQ(hidden_states.ndim(), 2) << "hidden_states must be 2D."; } - // MoE computation phase workspace tensors (allocated in prepare_moe() or prepare_moe_common()) + // MoE computation phase workspace tensors (allocated after instantiate_moe_runner() and in + // prepare_moe()). Tensor gemm1_output; Tensor activation_output; Tensor gemm2_output; @@ -400,7 +400,7 @@ class FusedMoeLauncher { int64_t moe_tactic{-1}; std::unique_ptr moe_runner; - void prepare_moe_common(int64_t& moe_tactic) { + void instantiate_moe_runner(int64_t& moe_tactic) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; // For FP8 block-scale (E4m3 activations, E4m3 weights) with DeepSeek FP8, use the // weights-only Runner constructor to match the original kernel path and numerics. 
@@ -450,6 +450,8 @@ class FusedMoeLauncher { bool use_routing_scales_on_input = false, bool use_deep_seek_fp8 = false) { check_routing(); + // Runner dictates contract of routing table; must instantiate runner before prepare_routing + instantiate_moe_runner(moe_tactic); prepare_routing(); // Execute routing @@ -608,8 +610,6 @@ class Bf16MoeLauncher : public FusedMoeLauncher { } void prepare_moe(int64_t& moe_tactic) override { - FusedMoeLauncher::prepare_moe_common(moe_tactic); - int32_t max_num_padded_tokens = workspace.total_max_padded_tokens; gemm1_output = alloc_tensor({max_num_padded_tokens, args->intermediate_size}, dl_bfloat16, hidden_states.device()); @@ -785,8 +785,6 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { } void prepare_moe(int64_t& moe_tactic) override { - FusedMoeLauncher::prepare_moe_common(moe_tactic); - int32_t max_num_padded_tokens_gemm1 = workspace.total_max_padded_tokens + args->num_experts; int32_t max_num_padded_tokens_gemm2 = workspace.total_max_padded_tokens; @@ -1105,8 +1103,6 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { } void prepare_moe(int64_t& moe_tactic) override { - FusedMoeLauncher::prepare_moe_common(moe_tactic); - // Calculate max_num_padded_tokens for gemm1 and gemm2 using maybeGetMinTokenCount int32_t max_num_padded_tokens_gemm1 = tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( @@ -1181,6 +1177,11 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { bool use_routing_scales_on_input = false, bool use_deep_seek_fp8 = false) override { check_routing(); + // Set DeepSeek mode before instantiating the runner so the correct + // constructor (weights-only vs act+weights) is chosen. 
+ args->mUseDeepSeekFp8 = quantization_type == Fp8QuantizationType::DeepSeekFp8; + // Runner dictates contract of routing table; must instantiate runner before prepare_routing + instantiate_moe_runner(moe_tactic); prepare_routing(); cudaStream_t routing_stream = get_stream(hidden_states.device()); @@ -1360,8 +1361,6 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { args->output1_scales_gate_scalar = nullptr; args->output2_scales_scalar = nullptr; - FusedMoeLauncher::prepare_moe_common(moe_tactic); - max_num_padded_tokens_gemm1 = tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( workspace.total_max_padded_tokens, args->intermediate_size, @@ -1383,8 +1382,8 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { workspace.hidden_states_scale_linear = nullptr; // MxInt4 doesn't use linear scale workspace.gemm1_output = gemm1_output.data_ptr(); workspace.gemm1_output_scale = nullptr; - // Note: activation_output and activation_output_scale are set by the base class - // prepare_moe_common() when gated activation is used + // activation_output and activation_output_scale are configured on the gated-activation path in + // this launcher's prepare_moe() implementation. workspace.gemm2_output = gemm2_output.data_ptr(); workspace.gemm2_output_scale = nullptr; } @@ -1511,7 +1510,6 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); workspace.total_max_padded_tokens = max_num_padded_tokens; - workspace.ProjUpTileN = tile_tokens_dim; workspace.routing_expert_indexes = static_cast(const_cast(expert_indices.data_ptr())); workspace.expert_weights = const_cast(expert_weights.data_ptr()); @@ -1605,8 +1603,6 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { ? static_cast(output2_scales_scalar.value().data_ptr()) : nullptr; - FusedMoeLauncher::prepare_moe_common(moe_tactic); - auto const sf_vec_size = mDtypeWeights == btg::Dtype::MxE2m1 ? 
32 : 16; max_num_padded_tokens_gemm1 = @@ -1640,8 +1636,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { workspace.gemm1_output_scale = gemm1_output_scale.has_value() ? static_cast(gemm1_output_scale.value().data_ptr()) : nullptr; - // Note: activation_output and activation_output_scale are set by the base class - // prepare_moe_common() when gated activation is used + // activation_output and activation_output_scale are configured on the gated-activation path in + // this launcher's prepare_moe() implementation. workspace.gemm2_output = gemm2_output.data_ptr(); workspace.gemm2_output_scale = nullptr; } @@ -1666,6 +1662,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { bool use_routing_scales_on_input = false, bool use_deep_seek_fp8 = false) override { check_routing(); + // Runner dictates contract of routing table; must instantiate runner before prepare_routing + instantiate_moe_runner(moe_tactic); prepare_routing(); // Execute routing diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 40150ad86d..1160a74669 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #include "flashinfer/exception.h" #include "flashinfer/trtllm/batched_gemm/KernelRunner.h" @@ -361,16 +361,14 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( if (isGatedAct) { ActType actType = activationTypeToGatedActType(activationType); tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { - // Swap A and B dtypes because transposeMmaOutput is hardcoded to true - .dtypeA = dtypeWeights, - .dtypeB = dtypeAct, + .dtypeA = dtypeAct, + .dtypeB = dtypeWeights, .dtypeC = dtypeAct, .actType = actType, .deepSeekFp8 = useDeepSeekFp8, .fusedAct = !useDeepSeekFp8, .routeAct = true, .staticBatch = false, - .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = useDeepSeekFp8 ? 
64 : 128, .useShuffledMatrix = useShuffledMatrix, @@ -379,16 +377,14 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( } else { EltwiseActType actType = activationTypeToEltwiseActType(activationType); tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { - // Swap A and B dtypes because transposeMmaOutput is hardcoded to true - .dtypeA = dtypeWeights, - .dtypeB = dtypeAct, + .dtypeA = dtypeAct, + .dtypeB = dtypeWeights, .dtypeC = dtypeAct, .eltwiseActType = actType, .deepSeekFp8 = useDeepSeekFp8, .fusedAct = false, .routeAct = true, .staticBatch = false, - .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = 128, .useShuffledMatrix = useShuffledMatrix, @@ -475,16 +471,14 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, int32_t tileTokensDim, bool useDeepSeekFp8, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) { tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { - // Swap A and B dtypes because transposeMmaOutput is hardcoded to true - .dtypeA = dtypeWeights, - .dtypeB = dtypeAct, + .dtypeA = dtypeAct, + .dtypeB = dtypeWeights, .dtypeC = dtypeOut, .eltwiseActType = EltwiseActType::None, .deepSeekFp8 = useDeepSeekFp8, .fusedAct = false, .routeAct = false, .staticBatch = false, - .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = useDeepSeekFp8 ? 64 : 128, .useShuffledMatrix = useShuffledMatrix, diff --git a/include/flashinfer/exception.h b/include/flashinfer/exception.h index aaaa2b5b3e..62d8422062 100644 --- a/include/flashinfer/exception.h +++ b/include/flashinfer/exception.h @@ -70,6 +70,17 @@ void write_to_stream(std::ostringstream& oss, T&& val, Args&&... args) { flashinfer::Warning(__FUNCTION__, __FILE__, __LINE__, msg).emit(); \ } while (0) +#define FLASHINFER_LOG(...) 
\ + do { \ + std::ostringstream oss; \ + write_to_stream(oss, ##__VA_ARGS__); \ + std::string msg = oss.str(); \ + if (msg.empty()) { \ + msg = "Log triggered"; \ + } \ + flashinfer::Log(__FUNCTION__, __FILE__, __LINE__, msg).emit(); \ + } while (0) + namespace flashinfer { class Error : public std::exception { private: @@ -101,6 +112,21 @@ class Warning { void emit() const { std::cerr << message_ << std::endl; } }; +class Log { + private: + std::string message_; + + public: + Log(const std::string& func, const std::string& file, int line, const std::string& message) { + std::ostringstream oss; + oss << "Log in function '" << func << "' " + << "at " << file << ":" << line << ": " << message; + message_ = oss.str(); + } + + void emit() const { std::cerr << message_ << std::endl; } +}; + } // namespace flashinfer #endif // FLASHINFER_EXCEPTION_H_ diff --git a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h index 64e958d9e8..9f21147d05 100644 --- a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h +++ b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h @@ -63,8 +63,11 @@ enum class EltwiseActType { }; struct TrtllmGenBatchedGemmRunnerOptions { + // Canonically, A is activation. batchedGemm::trtllm::gen::Dtype dtypeA; + // B is weight. batchedGemm::trtllm::gen::Dtype dtypeB; + // C is output. batchedGemm::trtllm::gen::Dtype dtypeC; ActType actType{ActType::SwiGlu}; EltwiseActType eltwiseActType{EltwiseActType::None}; @@ -72,6 +75,7 @@ struct TrtllmGenBatchedGemmRunnerOptions { bool fusedAct{false}; bool routeAct{false}; bool staticBatch{false}; + // If transposeMmaOutput is true, then A and B are swapped under the hood. 
bool transposeMmaOutput{false}; int32_t tileSize{8}; int32_t epilogueTileM{128}; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index df46aeed0b..fd171525be 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -86,6 +86,8 @@ inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethod }; } +// NOTE: The legacy CTA-based name is kept for interface compatibility even though these entries are +// counted at CGA granularity. inline int32_t getMaxNumCtasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts, int32_t tileTokensDim) { // For MoE, mNumTokens != 0 and the number of CTAs is known only at runtime. @@ -347,6 +349,8 @@ struct MoEWorkspace { // consumed by permuteGemm1 kernel void* token_scales = nullptr; + // NOTE: The legacy CTA-based name is kept for interface compatibility even though these entries + // are counted at CGA granularity. int32_t* cta_idx_xy_to_batch_idx = nullptr; int32_t* cta_idx_xy_to_mn_limit = nullptr; int32_t* num_non_exiting_ctas = nullptr; @@ -358,7 +362,6 @@ struct MoEWorkspace { float* permuted_hidden_states_scale = nullptr; // Gemm1 intermediate outputs: - int32_t ProjUpTileN{0}; void* gemm1_output = nullptr; float* gemm1_output_scale = nullptr;