diff --git a/benchmarks/bench_moe_deepseek.py b/benchmarks/bench_moe_deepseek.py index 191be0604b..d17b3aee4f 100644 --- a/benchmarks/bench_moe_deepseek.py +++ b/benchmarks/bench_moe_deepseek.py @@ -34,11 +34,24 @@ """ import argparse +from contextlib import contextmanager from dataclasses import dataclass import numpy as np import torch +@contextmanager +def cuda_profiler_range(name): + """Context manager for CUDA profiler + NVTX range.""" + torch.cuda.cudart().cudaProfilerStart() + torch.cuda.nvtx.range_push(name) + try: + yield + finally: + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() + + @dataclass class DeepSeekConfig: hidden_size: int = 7168 @@ -79,28 +92,44 @@ def is_sm100_family(): return props.major == 10 -def calc_tflops(n, ms, num_local_experts=None): - """Calculate TFLOPS for MoE computation. +NVFP4_BYTES = 9 / 16 # 0.5 bytes value + 1/16 byte block scale +BF16_BYTES = 2 + + +def calc_tflops(local_tokens, ms): + """Calculate TFLOPS using actual routed token count. - With EP, only tokens routed to local experts are computed. - Assumes uniform routing distribution across experts. + FC1: [M, H] x [H, 2I] + FC2: [M, I] x [I, H] + FLOPs = 2 * local_tokens * (H*2I + I*H) = 6 * local_tokens * H * I """ - if num_local_experts is None: - num_local_experts = CFG.num_experts + H = CFG.hidden_size + I = CFG.intermediate_size + flops = local_tokens * (2 * H * 2 * I + 2 * I * H) + return flops / (ms * 1e-3) / 1e12 - # Fraction of work done locally (assuming uniform distribution) - local_fraction = num_local_experts / CFG.num_experts - flops = ( - n - * CFG.top_k - * local_fraction # Only local expert pairs are computed - * ( - 2 * CFG.hidden_size * 2 * CFG.intermediate_size - + 2 * CFG.intermediate_size * CFG.hidden_size - ) +def calc_bw(local_tokens, active_experts, ms): + """Calculate achieved bandwidth in TB/s for MoE FC1 + FC2. + + Weights are read once per active expert. 
+ FC1: nvfp4 input [M, H] x nvfp4 weight [H, 2I] -> nvfp4 output [M, 2I] + FC2: nvfp4 input [M, I] x nvfp4 weight [I, H] -> bf16 output [M, H] + """ + H = CFG.hidden_size + I = CFG.intermediate_size + + weight_bytes = active_experts * (H * 2 * I + I * H) * NVFP4_BYTES + + act_bytes = ( + local_tokens * H * NVFP4_BYTES # FC1 input read + + local_tokens * 2 * I * NVFP4_BYTES # FC1 output write + + local_tokens * I * NVFP4_BYTES # FC2 input read + + local_tokens * H * BF16_BYTES # FC2 output write ) - return flops / (ms * 1e-3) / 1e12 + + total_bytes = weight_bytes + act_bytes + return total_bytes / (ms * 1e-3) / 1e12 def interleave(x, gs=64): @@ -343,15 +372,16 @@ def run(x, x_sf, router_logits, routing_bias, topk_values, topk_indices): "topk_indices": ti, } - times = bench_gpu_time( - run, - dry_run_iters=warmup, - repeat_iters=iters, - cold_l2_cache=True, - enable_cupti=use_cupti, - use_cuda_graph=use_cuda_graph, - input_kwargs=input_kwargs, - ) + with cuda_profiler_range("bench_cute_dsl"): + times = bench_gpu_time( + run, + dry_run_iters=warmup, + repeat_iters=iters, + cold_l2_cache=True, + enable_cupti=use_cupti, + use_cuda_graph=use_cuda_graph, + input_kwargs=input_kwargs, + ) return np.median(times) @@ -450,15 +480,16 @@ def run(hidden, sf, router_logits, routing_bias, topk_values, topk_indices): "topk_indices": ti, } - times = bench_gpu_time( - run, - dry_run_iters=warmup, - repeat_iters=iters, - cold_l2_cache=True, - enable_cupti=use_cupti, - use_cuda_graph=use_cuda_graph, - input_kwargs=input_kwargs, - ) + with cuda_profiler_range("bench_cutlass"): + times = bench_gpu_time( + run, + dry_run_iters=warmup, + repeat_iters=iters, + cold_l2_cache=True, + enable_cupti=use_cupti, + use_cuda_graph=use_cuda_graph, + input_kwargs=input_kwargs, + ) return np.median(times) @@ -576,15 +607,16 @@ def run(routing_logits, routing_bias, hidden_states, hidden_states_scale): "hidden_states_scale": hsc, } - times = bench_gpu_time( - run, - dry_run_iters=warmup, - repeat_iters=iters, - cold_l2_cache=True, - enable_cupti=use_cupti, - use_cuda_graph=use_cuda_graph, - input_kwargs=input_kwargs, - ) + with cuda_profiler_range("bench_trtllm"): + times = bench_gpu_time( + run, + dry_run_iters=warmup, + repeat_iters=iters, + cold_l2_cache=True, + enable_cupti=use_cupti, + use_cuda_graph=use_cuda_graph, + input_kwargs=input_kwargs, + ) return np.median(times) @@ -818,6 +850,7 @@ class BenchResult: tokens: int latency_ms: float tflops: float + bw_tb_s: float def run_benchmark( @@ -936,7 +969,9 @@ def _benchmark_single( ), } - # Build results + # Build results using actual routed token counts + local_tokens = histogram_record["local_tokens"] + active_experts = histogram_record["active_local_experts"] results = [] for backend, latency in lat.items(): results.append( @@ -944,7 +979,8 @@ def _benchmark_single( backend=backend, tokens=n, latency_ms=latency, - tflops=calc_tflops(n, latency, num_local), + tflops=calc_tflops(local_tokens, latency), + bw_tb_s=calc_bw(local_tokens, active_experts, latency), ) ) return results, histogram_record @@ -958,9 +994,9 @@ def _print_header( routing_bias_scale, ): """Print benchmark header.""" - print("\n" + "=" * 120) + print("\n" + "=" * 142) print(f"DeepSeek-V3 MoE Benchmark: CuteDSL vs CUTLASS vs TRTLLM (EP={ep_config})") - print("=" * 120) + print("=" * 142) print( f"Model: hidden={CFG.hidden_size}, intermediate={CFG.intermediate_size}, " f"experts={CFG.num_experts}, top_k={CFG.top_k}" @@ -975,28 +1011,28 @@ def _print_header( f"Routing bias scale: {routing_bias_scale} " 
f"(larger values tend to create expert imbalance)" ) - print("-" * 120) + print("-" * 142) print( f"{'Tokens':>6} | " - f"{'CuteDSL':^15} | " - f"{'CUTLASS':^15} | " - f"{'TRTLLM':^15} | " + f"{'CuteDSL':^22} | " + f"{'CUTLASS':^22} | " + f"{'TRTLLM':^22} | " f"{'Speedup (CuteDSL/X)':^18} | " f"{'Winner':^8} | " f"{'Active':^7} | " - f"{'Stats':^14}" + f"{'Tokens/slot':^14}" ) print( f"{'':>6} | " - f"{'ms':>7} {'TFLOPS':>7} | " - f"{'ms':>7} {'TFLOPS':>7} | " - f"{'ms':>7} {'TFLOPS':>7} | " + f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | " + f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | " + f"{'ms':>7} {'TFLOPS':>7} {'TB/s':>6} | " f"{'CUTLASS':>9} {'TRTLLM':>9} | " f"{'':^8} | " f"{'experts':^7} | " - f"{'min/max/median':^14}" + f"{'min/median/max':^14}" ) - print("-" * 120) + print("-" * 142) def _print_row(results, histogram_record): @@ -1015,14 +1051,14 @@ def _print_row(results, histogram_record): active_experts = f"{histogram_record['active_local_experts']:>3}" stats = ( f"{histogram_record['min_count']:>3}/" - f"{histogram_record['max_count']:>3}/" - f"{histogram_record['median_count']:>7.2f}" + f"{histogram_record['median_count']:>5.1f}/" + f"{histogram_record['max_count']:>4}" ) print( f"{cute.tokens:>6} | " - f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} | " - f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} | " - f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} | " + f"{cute.latency_ms:>7.3f} {cute.tflops:>7.1f} {cute.bw_tb_s:>6.1f} | " + f"{cutlass.latency_ms:>7.3f} {cutlass.tflops:>7.1f} {cutlass.bw_tb_s:>6.1f} | " + f"{trtllm.latency_ms:>7.3f} {trtllm.tflops:>7.1f} {trtllm.bw_tb_s:>6.1f} | " f"{speedup_cutlass:>8.2f}x {speedup_trtllm:>8.2f}x | " f"{winner:^8} | " f"{active_experts:>7} | " @@ -1032,8 +1068,12 @@ def _print_row(results, histogram_record): def _print_footer(ep_config, num_local): """Print benchmark footer.""" - print("-" * 120) + print("-" * 142) print("Speedup > 1.0 means CuteDSL is faster than that backend") + print( + f"TFLOPS/BW use actual routed token counts. " + f"BW assumes nvfp4 = {NVFP4_BYTES:.4f} B/elem, bf16 = {BF16_BYTES} B/elem." + ) def _collect_expert_histogram(inputs, num_local, local_offset): @@ -1060,6 +1100,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset): ) local_hist = expert_hist[local_offset : local_offset + num_local] local_hist_f32 = local_hist.to(torch.float32) + local_tokens = int(local_hist.sum().item()) active_local_experts = int((local_hist > 0).sum().item()) if local_hist.numel() > 0: min_count = int(local_hist.min().item()) @@ -1071,6 +1112,7 @@ def _collect_expert_histogram(inputs, num_local, local_offset): median_count = 0.0 return { + "local_tokens": local_tokens, "active_local_experts": active_local_experts, "min_count": min_count, "max_count": max_count, diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f4cb825d36..bad75c2cc5 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -36,6 +36,10 @@ using namespace batchedGemm::trtllm::gen; static BatchedGemmInterface::ModuleCache globalTrtllmGenBatchedGemmModuleCache; +static inline int32_t getClusterSizeInBatchDim(BatchedGemmOptions const& options) { + return options.mTransposeMmaOutput ? 
options.mClusterDimY : options.mClusterDimX;
+}
+
 std::vector<int64_t> prioritizePredefinedConfigs(
     int m, int n, int k, std::vector<int64_t> const& sortedIndices,
     batchedGemm::batchedGemm::BatchedGemmConfig const* configs) {
@@ -83,6 +87,30 @@ std::vector<int64_t> prioritizePredefinedConfigs(
   return prioritizedIndices;
 }
 
+static inline void setProblemDimensions(BatchedGemmData& gemmData, bool transposeMmaOutput,
+                                        int32_t m, int32_t n, int32_t k,
+                                        std::vector<int32_t> const& batchedTokens,
+                                        int32_t numTokens, int32_t numBatches,
+                                        int32_t maxNumCgasInBatchDim,
+                                        int32_t clusterSizeInBatchDim) {
+  gemmData.mProblemDimensions.mNumBatches = numBatches;
+  gemmData.mProblemDimensions.mNumTokens = numTokens;
+  gemmData.mProblemDimensions.mBatchM = !transposeMmaOutput;
+  gemmData.mProblemDimensions.mBatchedM =
+      transposeMmaOutput ? std::vector<int32_t>{} : batchedTokens;
+  gemmData.mProblemDimensions.mBatchedN =
+      transposeMmaOutput ? batchedTokens : std::vector<int32_t>{};
+  gemmData.mProblemDimensions.mM = transposeMmaOutput ? n : m;
+  gemmData.mProblemDimensions.mN = transposeMmaOutput ? m : n;
+  gemmData.mProblemDimensions.mK = k;
+  gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
+  gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
+  gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
+  gemmData.mProblemDimensions.mRank = 0;
+  gemmData.mProblemDimensions.mWorldSize = 1;
+  gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCgasInBatchDim * clusterSizeInBatchDim;
+}
+
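Note: `setProblemDimensions` centralizes a field-by-field setup that was previously copy-pasted in four call sites. A minimal Python sketch of its semantics (the dict is an illustrative stand-in for `BatchedGemmData`, not the real struct):

```python
def set_problem_dimensions(transpose_mma_output: bool, m: int, n: int, k: int,
                           batched_tokens: list, num_tokens: int, num_batches: int,
                           max_num_cgas_in_batch_dim: int,
                           cluster_size_in_batch_dim: int) -> dict:
    # When the MMA output is transposed, M and N swap roles, and the token
    # batching moves from the M dimension to the N dimension.
    return {
        "num_batches": num_batches,
        "num_tokens": num_tokens,
        "batch_m": not transpose_mma_output,
        "batched_m": [] if transpose_mma_output else batched_tokens,
        "batched_n": batched_tokens if transpose_mma_output else [],
        "m": n if transpose_mma_output else m,
        "n": m if transpose_mma_output else n,
        "k": k,
        # One CGA spans cluster_size CTAs in the batch dim, so the CTA budget
        # in the token dim is the CGA budget scaled by the cluster size.
        "max_num_ctas_in_token_dim":
            max_num_cgas_in_batch_dim * cluster_size_in_batch_dim,
    }
```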
 TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
     TrtllmGenBatchedGemmRunnerOptions const& options_)
     : mOptions(options_) {
@@ -93,16 +121,27 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
   mPassingConfigIndices.clear();
   for (size_t i = 0; i < bmm.getNumBatchedGemmConfigs(); ++i) {
+    // The kernel config.
     auto const options = configs[i].mOptions;
-    auto const tileSize = mOptions.transposeMmaOutput ? options.mTileN : options.mTileM;
-    // When we include low-latency kernels we can set transposeMmaOutput via constructor
-    if (options.mDtypeA == mOptions.dtypeA && options.mDtypeB == mOptions.dtypeB &&
-        options.mDtypeC == mOptions.dtypeC && options.mUseDeepSeekFp8 == mOptions.deepSeekFp8 &&
-        options.mTransposeMmaOutput == mOptions.transposeMmaOutput &&
+    // The tile size in CGA granularity.
+    auto const tileSize = options.mTransposeMmaOutput ? options.mTileN * options.mClusterDimY
+                                                      : options.mTileM * options.mClusterDimX;
+    // Check if kernel dtype matches runner config.
+    bool const dtypeMatch =
+        options.mTransposeMmaOutput
+            ? (options.mDtypeA == mOptions.dtypeB && options.mDtypeB == mOptions.dtypeA)
+            : (options.mDtypeA == mOptions.dtypeA && options.mDtypeB == mOptions.dtypeB);
+    // Check if kernel weight layout matches runner config.
+    bool const layoutAndShuffleMatch =
+        options.mTransposeMmaOutput ? (options.mUseShuffledMatrix == mOptions.useShuffledMatrix &&
+                                       options.mLayoutA == mOptions.weightLayout)
+                                    : (options.mUseShuffledMatrix == mOptions.useShuffledMatrix &&
+                                       options.mLayoutB == mOptions.weightLayout);
+    if (dtypeMatch && options.mDtypeC == mOptions.dtypeC &&
+        options.mUseDeepSeekFp8 == mOptions.deepSeekFp8 &&
         (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct &&
         options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
-        tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix &&
-        options.mLayoutA == mOptions.weightLayout) {
+        tileSize == mOptions.tileSize && layoutAndShuffleMatch) {
       if (options.mFusedAct) {
         if (options.mActType != static_cast<ActType>(mOptions.actType)) {
           continue;
@@ -111,8 +150,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
       if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) {
         continue;
       }
-
-      if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) {
+      if (options.mEpilogueTileM == mOptions.epilogueTileM) {
         mPassingConfigIndices.push_back(i);
       }
     }
@@ -126,46 +164,40 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
             << ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8
             << ", mActType: " << (int64_t)mOptions.actType
             << ", mEltwiseActType: " << (int64_t)mOptions.eltwiseActType
-            << ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput
+            << ", mTransposeMmaOutput: auto-tuned"
             << ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct
             << ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize;
   FLASHINFER_CHECK(!mPassingConfigIndices.empty(), error_msg.str());
 }
 
+int32_t TrtllmGenBatchedGemmRunner::getConfigClusterSizeInBatchDim(int32_t configIndex) const {
+  auto const bmm = BatchedGemmInterface();
+  auto const configs = bmm.getBatchedGemmConfigs();
+  int64_t const numConfigs = static_cast<int64_t>(bmm.getNumBatchedGemmConfigs());
+  FLASHINFER_CHECK(configIndex >= 0 && configIndex < numConfigs,
+                   "Invalid batched GEMM config index ", configIndex, ". Valid range is [0, ",
+                   numConfigs - 1, "].");
+  return getClusterSizeInBatchDim(configs[configIndex].mOptions);
+}
+
 size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(
     int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
-    int32_t numBatches, int32_t maxNumCtasInBatchDim, int32_t configIndex) const {
-  BatchedGemmData gemmData{};
-  gemmData.mProblemDimensions.mNumBatches = numBatches;
-  gemmData.mProblemDimensions.mNumTokens = numTokens;
-  gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput;
-  gemmData.mProblemDimensions.mBatchedM =
-      mOptions.transposeMmaOutput ? std::vector<int32_t>{} : batchedTokens;
-  gemmData.mProblemDimensions.mBatchedN =
-      mOptions.transposeMmaOutput ? batchedTokens : std::vector<int32_t>{};
-  gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
-  gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
-  gemmData.mProblemDimensions.mK = k;
-  gemmData.mProblemDimensions.mRank = 0;
-  gemmData.mProblemDimensions.mWorldSize = 1;
-  gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
-
-  gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
-  gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
-  gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
-
+    int32_t numBatches, int32_t maxNumCgasInBatchDim, int32_t configIndex) const {
   auto bmm = BatchedGemmInterface();
-
   auto const configs = bmm.getBatchedGemmConfigs();
-
   auto const& config = configs[configIndex];
+  BatchedGemmData gemmData{};
+  int32_t const clusterSizeInBatchDim = getClusterSizeInBatchDim(config.mOptions);
+  setProblemDimensions(gemmData, config.mOptions.mTransposeMmaOutput, m, n, k, batchedTokens,
+                       numTokens, numBatches, maxNumCgasInBatchDim, clusterSizeInBatchDim);
+
   return bmm.getWorkspaceSizeInBytes(config, gemmData);
 }
 
 void TrtllmGenBatchedGemmRunner::run(
     int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
-    int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
+    int32_t numBatches, int32_t maxNumCgasInBatchDim, void const* a, void const* sfA, void const* b,
     void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
     float const* scaleGateC, float const* ptrBias, float const* ptrAlpha, float const* ptrBeta,
     float const* ptrClampLimit, void* c, void* outSfC, int32_t const* routeMap,
@@ -179,6 +211,9 @@ void TrtllmGenBatchedGemmRunner::run(
   auto const configs = bmm.getBatchedGemmConfigs();
   auto const& config = configs[configIndex];
+  bool const transposeMmaOutput = config.mOptions.mTransposeMmaOutput;
+  int32_t const clusterSizeInBatchDim = getClusterSizeInBatchDim(config.mOptions);
+  int32_t const maxNumCtasInBatchDim = maxNumCgasInBatchDim * clusterSizeInBatchDim;
   // printf("running config %d: %s\n", configIndex, config.mFunctionName);
 
   FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0");
@@ -194,8 +229,8 @@ void TrtllmGenBatchedGemmRunner::run(
   }
 
   if (!mOptions.staticBatch && numTokens != 0) {
-    FLASHINFER_CHECK(maxNumCtasInBatchDim > 0,
-                     "Batched GEMM with dynamic batching requires maxNumCtasInBatchDim > 0");
+    FLASHINFER_CHECK(maxNumCgasInBatchDim > 0,
+                     "Batched GEMM with dynamic batching requires maxNumCgasInBatchDim > 0");
   }
 
   if (mOptions.routeAct) {
@@ -204,35 +239,20 @@ void TrtllmGenBatchedGemmRunner::run(
   }
 
   // Dims
-  gemmData.mProblemDimensions.mNumBatches = numBatches;
-  gemmData.mProblemDimensions.mNumTokens = numTokens;
-  gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput;
-  gemmData.mProblemDimensions.mBatchedM =
-      mOptions.transposeMmaOutput ? std::vector<int32_t>{} : batchedTokens;
-  gemmData.mProblemDimensions.mBatchedN =
-      mOptions.transposeMmaOutput ? batchedTokens : std::vector<int32_t>{};
-  gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
-  gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
-  gemmData.mProblemDimensions.mK = k;
-  gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
-  gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
-  gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
-  gemmData.mProblemDimensions.mRank = 0;
-  gemmData.mProblemDimensions.mWorldSize = 1;
+  setProblemDimensions(gemmData, transposeMmaOutput, m, n, k, batchedTokens, numTokens, numBatches,
+                       maxNumCgasInBatchDim, clusterSizeInBatchDim);
 
   // Inputs
-  gemmData.mInputBuffers.mPtrA = mOptions.transposeMmaOutput ? b : a;
-  gemmData.mInputBuffers.mPtrSfA = mOptions.transposeMmaOutput ? sfB : sfA;
-  gemmData.mInputBuffers.mPtrB = mOptions.transposeMmaOutput ? a : b;
-  gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB;
+  gemmData.mInputBuffers.mPtrA = transposeMmaOutput ? b : a;
+  gemmData.mInputBuffers.mPtrSfA = transposeMmaOutput ? sfB : sfA;
+  gemmData.mInputBuffers.mPtrB = transposeMmaOutput ? a : b;
+  gemmData.mInputBuffers.mPtrSfB = transposeMmaOutput ? sfA : sfB;
   gemmData.mInputBuffers.mPtrScaleC = scaleC;
   gemmData.mInputBuffers.mPtrScaleGate = scaleGateC;
   // For simplicity pass set scaleAct to scaleGateC
   gemmData.mInputBuffers.mPtrScaleAct = scaleGateC;
-  gemmData.mInputBuffers.mPtrPerTokenSfA =
-      mOptions.transposeMmaOutput ? perTokensSfB : perTokensSfA;
-  gemmData.mInputBuffers.mPtrPerTokenSfB =
-      mOptions.transposeMmaOutput ? perTokensSfA : perTokensSfB;
+  gemmData.mInputBuffers.mPtrPerTokenSfA = transposeMmaOutput ? perTokensSfB : perTokensSfA;
+  gemmData.mInputBuffers.mPtrPerTokenSfB = transposeMmaOutput ? perTokensSfA : perTokensSfB;
   gemmData.mInputBuffers.mPtrBias = ptrBias;
   gemmData.mInputBuffers.mPtrGatedActAlpha = ptrAlpha;
   gemmData.mInputBuffers.mPtrGatedActBeta = ptrBeta;
@@ -240,8 +260,6 @@ void TrtllmGenBatchedGemmRunner::run(
 
   gemmData.mInputBuffers.mPtrRouteMap = routeMap;
 
-  gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
-
   // Pointer to total number of padded tokens
   gemmData.mInputBuffers.mPtrTotalNumPaddedTokens = totalNumPaddedTokens;
   gemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
@@ -255,6 +273,15 @@ void TrtllmGenBatchedGemmRunner::run(
   int32_t multiProcessorCount;
   cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
 
+  if (getBoolEnv("TRTLLM_BATCHED_GEMM_PRINT_NAME")) {
+    FLASHINFER_LOG("NumBatches", numBatches, ", MaxNumCgasInBatchDim", maxNumCgasInBatchDim,
+                   ", MaxNumCtasInBatchDim", maxNumCtasInBatchDim, ", ShapeMNK",
+                   gemmData.mProblemDimensions.mM, gemmData.mProblemDimensions.mN,
+                   gemmData.mProblemDimensions.mK, ", ValidShapeMNK",
+                   gemmData.mProblemDimensions.mValidM, gemmData.mProblemDimensions.mValidN,
+                   gemmData.mProblemDimensions.mValidK, ", Kernel", config.mFunctionName);
+  }
+
   // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
   bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
 
@@ -266,7 +293,9 @@ void TrtllmGenBatchedGemmRunner::run(
                    "Error occurred when running GEMM!"
                   " (numBatches: ", numBatches, ", GemmMNK: ", m, " ", n, " ", k,
                   ", Kernel: ", config.mFunctionName,
-                  ")");
+                  ", transposeMmaOutput: ", transposeMmaOutput, ", configIndex: ", configIndex,
+                  ", maxNumCgasInBatchDim: ", maxNumCgasInBatchDim,
+                  ", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, ")");
 }
 
 void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k,
@@ -275,7 +304,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k,
                                      void* outSfC, void* workspace, CUstream stream, int device,
                                      int32_t configIndex, bool enable_pdl) {
   // Dispatch with block scaling factors and with static batching.
-  run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0,
+  run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCgasInBatchDim */ 0,
       a, sfA, b, sfB,
       /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
       /* scaleC */ nullptr, /* scaleGateC */ nullptr, /* ptrBias */ nullptr, /* ptrAlpha */ nullptr,
@@ -293,7 +322,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k,
                                      void* outSfC, void* workspace, CUstream stream, int device,
                                      int32_t configIndex, bool enable_pdl) {
   // Dispatch with block scaling factors and with static batching.
-  run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0,
+  run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCgasInBatchDim */ 0,
       a, sfA, b, sfB,
       /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
       /* scaleC */ nullptr, /* scaleGateC */ nullptr, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, c,
@@ -309,7 +338,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k,
                                      void* c, void* workspace, CUstream stream, int device,
                                      int32_t configIndex, bool enable_pdl) {
   // Dispatch with block scaling factors and with static batching.
-  run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0,
+  run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCgasInBatchDim */ 0,
      a, /* sfA */ nullptr, b, /* sfB */ nullptr,
      /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, scaleC, scaleGateC,
      /* ptrBias */ nullptr, /* ptrAlpha */ nullptr,
@@ -322,36 +351,22 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k,
 
 std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
     int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
-    int32_t numBatches, int32_t maxNumCtasInBatchDim) const {
+    int32_t numBatches, int32_t maxNumCgasInBatchDim) const {
   auto const bmm = BatchedGemmInterface();
   auto const configs = bmm.getBatchedGemmConfigs();
 
   int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
 
-  BatchedGemmData gemmData{};
-  // Dims
-  gemmData.mProblemDimensions.mNumBatches = numBatches;
-  gemmData.mProblemDimensions.mNumTokens = numTokens;
-  gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput;
-  gemmData.mProblemDimensions.mBatchedM =
-      mOptions.transposeMmaOutput ? std::vector<int32_t>{} : batchedTokens;
-  gemmData.mProblemDimensions.mBatchedN =
-      mOptions.transposeMmaOutput ? batchedTokens : std::vector<int32_t>{};
-  gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
-  gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
-  gemmData.mProblemDimensions.mK = k;
-  gemmData.mProblemDimensions.mRank = 0;
-  gemmData.mProblemDimensions.mWorldSize = 1;
-  gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
-
-  gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
-  gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
-  gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
-
-  auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) {
+  auto cmpFunc = [&configs, &bmm, &multiProcessorCount, &m, &n, &k, &batchedTokens, &numTokens,
+                  &numBatches, &maxNumCgasInBatchDim](int64_t idx0, int64_t idx1) {
     auto const& optionsA = configs[idx0].mOptions;
     auto const& optionsB = configs[idx1].mOptions;
-    int32_t sizeK = gemmData.mProblemDimensions.mK;
+    int32_t sizeK = k;
+
+    // Keep comparator stable across mixed transpose modes.
+    if (optionsA.mTransposeMmaOutput != optionsB.mTransposeMmaOutput) {
+      return optionsA.mTransposeMmaOutput;
+    }
 
     // Tier 0: K < tileK, prefer higher efficiency.
     if (optionsA.mTileK != optionsB.mTileK) {
@@ -385,6 +400,10 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
     // Tier 2+: When previous comparators are the same, and when the number of estimated CTAs is on
     // the larger side, prefer persistent tile scheduler.
     if (optionsA.mTileScheduler != optionsB.mTileScheduler) {
+      BatchedGemmData gemmData{};
+      int32_t const clusterSizeInBatchDim = getClusterSizeInBatchDim(optionsA);
+      setProblemDimensions(gemmData, optionsA.mTransposeMmaOutput, m, n, k, batchedTokens,
+                           numTokens, numBatches, maxNumCgasInBatchDim, clusterSizeInBatchDim);
       auto options = bmm.getOptionsFromConfigAndData(configs[idx0], gemmData);
       auto numCtas = bmm.getNumCtas(options, gemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
       if (numCtas > multiProcessorCount) {
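Since the passing set can now mix transpose and non-transpose kernels, the comparator orders across that axis first, so `std::sort` still sees a strict weak ordering. A simplified Python sketch of the tier structure (`transpose` and `tile_k` are stand-ins for the real option fields; the tier-0/1/2 rules are abbreviated):

```python
from functools import cmp_to_key

def compare_configs(a: dict, b: dict) -> int:
    """Return -1 when config `a` should be tried before config `b`."""
    # Stability tier: never interleave transpose modes mid-comparison,
    # otherwise the ordering would not be transitive.
    if a["transpose"] != b["transpose"]:
        return -1 if a["transpose"] else 1
    # Tier 0 (illustrative): prefer the larger tile_k for K-efficiency.
    if a["tile_k"] != b["tile_k"]:
        return -1 if a["tile_k"] > b["tile_k"] else 1
    return 0

configs = [{"transpose": t, "tile_k": k} for t in (False, True) for k in (64, 128)]
ordered = sorted(configs, key=cmp_to_key(compare_configs))
```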
@@ -408,23 +427,44 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
   // Filter out invalid configs.
   std::vector<int64_t> validConfigIndices;
   for (auto const& configIndex : prioritizedIndices) {
+    BatchedGemmData gemmData{};
+    auto const transposeMmaOutput = configs[configIndex].mOptions.mTransposeMmaOutput;
+    int32_t const clusterSizeInBatchDim = getClusterSizeInBatchDim(configs[configIndex].mOptions);
+    setProblemDimensions(gemmData, transposeMmaOutput, m, n, k, batchedTokens, numTokens,
+                         numBatches, maxNumCgasInBatchDim, clusterSizeInBatchDim);
     auto isValidConfig = bmm.isValidConfig(configs[configIndex], gemmData);
     if (isValidConfig) {
       validConfigIndices.push_back(configIndex);
     }
   }
 
-  FLASHINFER_CHECK(!validConfigIndices.empty(),
-                   "No valid config found for the given problem shape");
+  std::ostringstream error_msg;
+  if (validConfigIndices.empty()) {
+    int64_t numTransposeConfigs = 0;
+    for (auto const& configIndex : prioritizedIndices) {
+      if (configs[configIndex].mOptions.mTransposeMmaOutput) {
+        ++numTransposeConfigs;
+      }
+    }
+    error_msg << "No valid config found for the given problem shape"
+              << " (m=" << m << ", n=" << n << ", k=" << k << ", numTokens=" << numTokens
+              << ", numBatches=" << numBatches << ", maxNumCgasInBatchDim=" << maxNumCgasInBatchDim
+              << ", passingConfigs=" << mPassingConfigIndices.size()
+              << ", prioritizedConfigs=" << prioritizedIndices.size()
+              << ", transposeConfigs=" << numTransposeConfigs
+              << ", nonTransposeConfigs=" << (prioritizedIndices.size() - numTransposeConfigs)
+              << ")";
+  }
+  FLASHINFER_CHECK(!validConfigIndices.empty(), error_msg.str());
 
   return validConfigIndices;
 }
 
 int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex(
     int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
-    int32_t numBatches, int32_t maxNumCtasInBatchDim) const {
+    int32_t numBatches, int32_t maxNumCgasInBatchDim) const {
   auto const validConfigIndices =
-      getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
+      getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCgasInBatchDim);
 
   return validConfigIndices[0];
 }
 
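Each candidate is now validated against problem dimensions built from its own transpose flag and cluster size, where previously one `gemmData` was shared by every candidate. A Python sketch of the selection flow, reusing the hypothetical `set_problem_dimensions` helper sketched earlier:

```python
def get_valid_config_indices(prioritized, is_valid_config, m, n, k, batched_tokens,
                             num_tokens, num_batches, max_num_cgas):
    """Keep priority order; validate each config with dims built for that config."""
    valid = []
    for cfg in prioritized:
        dims = set_problem_dimensions(cfg["transpose"], m, n, k, batched_tokens,
                                      num_tokens, num_batches, max_num_cgas,
                                      cfg["cluster_size"])
        if is_valid_config(cfg, dims):
            valid.append(cfg["index"])
    if not valid:
        raise RuntimeError(f"No valid config for m={m}, n={n}, k={k}, "
                           f"numTokens={num_tokens}, numBatches={num_batches}")
    return valid  # valid[0] is the default config
```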
@@ -433,33 +473,15 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t
                                                    int32_t k, std::vector<int32_t> const& batchedTokens,
                                                    int32_t numTokens, int32_t numBatches,
-                                                   int32_t maxNumCtasInBatchDim) const {
+                                                   int32_t maxNumCgasInBatchDim) const {
   auto const bmm = BatchedGemmInterface();
   auto const configs = bmm.getBatchedGemmConfigs();
+  auto const& config = configs[configIndex];
 
   BatchedGemmData gemmData{};
-  // Dims
-  gemmData.mProblemDimensions.mNumBatches = numBatches;
-  gemmData.mProblemDimensions.mNumTokens = numTokens;
-  gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput;
-  gemmData.mProblemDimensions.mBatchedM =
-      mOptions.transposeMmaOutput ? std::vector<int32_t>{} : batchedTokens;
-  gemmData.mProblemDimensions.mBatchedN =
-      mOptions.transposeMmaOutput ? batchedTokens : std::vector<int32_t>{};
-  gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
-  gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
-  gemmData.mProblemDimensions.mK = k;
-  gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
-  gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
-  gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
-  gemmData.mProblemDimensions.mRank = 0;
-  gemmData.mProblemDimensions.mWorldSize = 1;
-  gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
-  gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
-  gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
-  gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
-
-  auto const& config = configs[configIndex];
+  int32_t const clusterSizeInBatchDim = getClusterSizeInBatchDim(config.mOptions);
+  setProblemDimensions(gemmData, config.mOptions.mTransposeMmaOutput, m, n, k, batchedTokens,
+                       numTokens, numBatches, maxNumCgasInBatchDim, clusterSizeInBatchDim);
 
   return bmm.isValidConfig(config, gemmData);
 }
diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu
index 64fece5021..8fde9f1846 100644
--- a/csrc/trtllm_fused_moe_kernel_launcher.cu
+++ b/csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -106,6 +106,32 @@ std::set<int64_t> computeSelectedTileN(std::vector<int64_t> const& supported_til
   return selected_tile_nums;
 }
 
+template <typename LauncherType>
+// Select a launcher
+LauncherType& get_launcher(
+    std::unordered_map<int32_t, std::unique_ptr<LauncherType>>& launchers_map,
+    std::set<int64_t> const& selected_tile_nums, int64_t& tile_N, int64_t& config,
+    char const* op_name) {
+  FLASHINFER_CHECK(!selected_tile_nums.empty(), op_name, ": no available tile_N candidates");
+
+  if (tile_N == -1) {
+    tile_N = *selected_tile_nums.begin();
+  }
+
+  auto it = launchers_map.find(static_cast<int32_t>(tile_N));
+  if (it == launchers_map.end()) {
+    auto const requestedTileN = tile_N;
+    tile_N = *selected_tile_nums.begin();
+    config = -1;
+    it = launchers_map.find(static_cast<int32_t>(tile_N));
+    FLASHINFER_CHECK(it != launchers_map.end(), op_name, ": failed to select launcher for tile_N ",
+                     tile_N, " after fallback from requested tile_N ", requestedTileN,
+                     " (selected_tile_count=", selected_tile_nums.size(), ")");
+  }
+
+  return *(it->second);
+}
+
 class FusedMoeLauncher {
  protected:
  Optional<Tensor> routing_logits;
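The fallback rule in `get_launcher`: an unknown or `-1` `tile_N` degrades to the first supported tile instead of throwing `std::out_of_range` from `.at()`, and `config` is reset to `-1` so a tactic tuned for a different tile is not reapplied. A Python sketch of that policy (a dict and sorted list stand in for the map and set):

```python
def get_launcher(launchers_map: dict, selected_tile_nums: list,
                 tile_N: int, config: int, op_name: str):
    """Return (launcher, tile_N, config), falling back to the first tile_N."""
    if not selected_tile_nums:
        raise RuntimeError(f"{op_name}: no available tile_N candidates")
    if tile_N == -1:
        tile_N = min(selected_tile_nums)   # default tile
    if tile_N not in launchers_map:        # unknown tile: fall back
        tile_N = min(selected_tile_nums)
        config = -1                        # the stale tactic index is invalid now
        if tile_N not in launchers_map:
            raise RuntimeError(f"{op_name}: failed to select launcher "
                               f"for tile_N {tile_N}")
    return launchers_map[tile_N], tile_N, config
```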
@@ -255,8 +281,11 @@ class FusedMoeLauncher {
   Tensor cta_idx_xy_to_mn_limit;
   Tensor num_non_exiting_ctas;
 
-  void prepare_routing_common() {
-    // Allocate routing phase workspace tensors
+  void prepare_routing_common(int32_t clusterSize) {
+    // Allocate routing phase workspace tensors.
+    // tile_tokens_dim is cluster-level tile (tile * clusterSize) for config matching.
+    // The routing kernel internally expands cluster-level tiles to CTA-level output entries.
+
     num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device());
     int32_t max_num_padded_tokens =
         tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
@@ -281,8 +310,10 @@ class FusedMoeLauncher {
                      // and max number of experts
                      hidden_states.device());
 
-    int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
+    // Buffer sized at CTA count: routing kernel expands CGA entries to CTA entries
+    int32_t max_num_cgas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCgasInBatchDim(
         args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
+    int32_t max_num_ctas = max_num_cgas * clusterSize;
     cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device());
@@ -292,7 +323,6 @@ class FusedMoeLauncher {
 
     workspace.total_num_padded_tokens = static_cast<int32_t*>(total_num_padded_tokens.data_ptr());
     workspace.total_max_padded_tokens = max_num_padded_tokens;
-    workspace.ProjUpTileN = tile_tokens_dim;
     workspace.routing_expert_indexes = static_cast<int32_t*>(expert_indexes.data_ptr());
     workspace.permuted_idx_size = static_cast<int32_t*>(total_num_padded_tokens.data_ptr());
     workspace.expanded_idx_to_permuted_idx =
@@ -319,7 +349,7 @@ class FusedMoeLauncher {
   int64_t moe_tactic{-1};
   std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner;
 
-  void prepare_moe_common(int64_t& moe_tactic) {
+  void instantiate_moe_runner(int64_t& moe_tactic) {
     using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;
     // For FP8 block-scale (E4m3 activations, E4m3 weights) with DeepSeek FP8, use the
     // weights-only Runner constructor to match the original kernel path and numerics.
@@ -358,7 +388,7 @@ class FusedMoeLauncher {
  public:
   virtual void check_routing() const = 0;
-  virtual void prepare_routing() = 0;
+  virtual void prepare_routing(int32_t clusterSize) = 0;
   virtual void check_moe() const = 0;
   virtual void prepare_moe(int64_t& moe_tactic) = 0;
@@ -369,10 +399,14 @@
                            bool use_routing_scales_on_input = false,
                            bool use_deep_seek_fp8 = false) {
     check_routing();
-    prepare_routing();
+    // Runner dictates contract of routing table; must instantiate runner before prepare_routing
+    instantiate_moe_runner(moe_tactic);
+    int32_t clusterSize = moe_runner->getConfigClusterSizeInBatchDim(moe_tactic);
+    prepare_routing(clusterSize);
 
     // Execute routing
-    tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
+    tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim,
+                                                                         clusterSize);
     cudaStream_t routing_stream = get_stream(hidden_states.device());
 
     routing_runner.run(
@@ -476,8 +510,8 @@ class Bf16MoeLauncher : public FusedMoeLauncher {
     // TODO n_group, topk_group validation?
} - void prepare_routing() override { - FusedMoeLauncher::prepare_routing_common(); + void prepare_routing(int32_t clusterSize) override { + FusedMoeLauncher::prepare_routing_common(clusterSize); args->mDtypeElt = btg::Dtype::Bfloat16; args->mUseDeepSeekFp8 = false; @@ -517,8 +551,6 @@ class Bf16MoeLauncher : public FusedMoeLauncher { } void prepare_moe(int64_t& moe_tactic) override { - FusedMoeLauncher::prepare_moe_common(moe_tactic); - int32_t max_num_padded_tokens = workspace.total_max_padded_tokens; gemm1_output = alloc_tensor({max_num_padded_tokens, args->intermediate_size}, dl_bfloat16, hidden_states.device()); @@ -617,8 +649,8 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { void check_routing() const override { FusedMoeLauncher::check_routing_common(); } - void prepare_routing() override { - FusedMoeLauncher::prepare_routing_common(); + void prepare_routing(int32_t clusterSize) override { + FusedMoeLauncher::prepare_routing_common(clusterSize); auto dtype = hidden_states.dtype(); if (dtype == dl_float16) { @@ -687,8 +719,6 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { } void prepare_moe(int64_t& moe_tactic) override { - FusedMoeLauncher::prepare_moe_common(moe_tactic); - int32_t max_num_padded_tokens_gemm1 = workspace.total_max_padded_tokens + args->num_experts; int32_t max_num_padded_tokens_gemm2 = workspace.total_max_padded_tokens; @@ -881,8 +911,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { << "num_experts must be greater or equal to local_num_experts + local_expert_offset"; } - void prepare_routing() override { - FusedMoeLauncher::prepare_routing_common(); + void prepare_routing(int32_t clusterSize) override { + FusedMoeLauncher::prepare_routing_common(clusterSize); auto dtype = hidden_states.dtype(); if (dtype == dl_float16) { @@ -995,8 +1025,6 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { } void prepare_moe(int64_t& moe_tactic) override { - FusedMoeLauncher::prepare_moe_common(moe_tactic); - // Calculate max_num_padded_tokens for gemm1 and gemm2 using maybeGetMinTokenCount int32_t max_num_padded_tokens_gemm1 = tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( @@ -1071,10 +1099,17 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { bool use_routing_scales_on_input = false, bool use_deep_seek_fp8 = false) override { check_routing(); - prepare_routing(); + // Set DeepSeek mode before instantiating the runner so the correct + // constructor (weights-only vs act+weights) is chosen. 
+    args->mUseDeepSeekFp8 = quantization_type == Fp8QuantizationType::DeepSeekFp8;
+    // Runner dictates contract of routing table; must instantiate runner before prepare_routing
+    instantiate_moe_runner(moe_tactic);
+    int32_t clusterSize = moe_runner->getConfigClusterSizeInBatchDim(moe_tactic);
+    prepare_routing(clusterSize);
 
     cudaStream_t routing_stream = get_stream(hidden_states.device());
-    tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
+    tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim,
+                                                                         clusterSize);
 
     // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
     bool use_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0;
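This reordering is the core control-flow change of the patch: the GEMM runner is instantiated first, the selected tactic's cluster size is queried from it, and only then are the routing tables prepared, because those buffers must be sized for the CTA count the chosen kernel will actually launch. A Python sketch of the flow (the method names mirror the diff; the launcher object is an illustrative stand-in):

```python
def run_impl(launcher, moe_tactic: int):
    launcher.check_routing()
    # Runner first: the routing-table contract depends on the selected config.
    launcher.instantiate_moe_runner(moe_tactic)
    cluster_size = launcher.moe_runner.get_config_cluster_size_in_batch_dim(moe_tactic)
    # Routing buffers are sized in CTA granularity = CGA count * cluster size.
    launcher.prepare_routing(cluster_size)
    launcher.run_routing(tile_tokens_dim=launcher.tile_tokens_dim,
                         cluster_size=cluster_size)
```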
@@ -1178,8 +1213,8 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher {
 
   void check_routing() const override { FusedMoeLauncher::check_routing_common(); }
 
-  void prepare_routing() override {
-    FusedMoeLauncher::prepare_routing_common();
+  void prepare_routing(int32_t clusterSize) override {
+    FusedMoeLauncher::prepare_routing_common(clusterSize);
 
     args->mDtypeElt = mDtypeAct;
     args->mUseDeepSeekFp8 = false;
@@ -1224,8 +1259,6 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher {
     args->output1_scales_gate_scalar = nullptr;
     args->output2_scales_scalar = nullptr;
 
-    FusedMoeLauncher::prepare_moe_common(moe_tactic);
-
     max_num_padded_tokens_gemm1 =
         tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
             workspace.total_max_padded_tokens, args->intermediate_size,
@@ -1363,7 +1396,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
     FusedMoeLauncher::check_routing_common();
   }
 
-  void prepare_routing() override {
+  void prepare_routing(int32_t clusterSize) override {
     num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device());
     int32_t max_num_padded_tokens =
         tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
@@ -1379,15 +1412,15 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
     expert_count_histogram =
         alloc_tensor({size_of_expert_count_histogram}, dl_int32, hidden_states.device());
 
-    int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
+    int32_t max_num_cgas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCgasInBatchDim(
         args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
+    int32_t max_num_ctas = max_num_cgas * clusterSize;
     cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device());
     cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device());
     num_non_exiting_ctas = alloc_tensor({1}, dl_int32, hidden_states.device());
 
     workspace.total_num_padded_tokens = static_cast<int32_t*>(total_num_padded_tokens.data_ptr());
     workspace.total_max_padded_tokens = max_num_padded_tokens;
-    workspace.ProjUpTileN = tile_tokens_dim;
     workspace.routing_expert_indexes =
         static_cast<int32_t*>(const_cast<void*>(expert_indices.data_ptr()));
     workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
@@ -1476,8 +1509,6 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
             ? static_cast<float*>(output2_scales_scalar.value().data_ptr())
             : nullptr;
 
-    FusedMoeLauncher::prepare_moe_common(moe_tactic);
-
     auto const sf_vec_size = mDtypeWeights == btg::Dtype::MxE2m1 ? 32 : 16;
 
     max_num_padded_tokens_gemm1 =
@@ -1537,10 +1568,14 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
                            bool use_routing_scales_on_input = false,
                            bool use_deep_seek_fp8 = false) override {
     check_routing();
-    prepare_routing();
+    // Runner dictates contract of routing table; must instantiate runner before prepare_routing
+    instantiate_moe_runner(moe_tactic);
+    int32_t clusterSize = moe_runner->getConfigClusterSizeInBatchDim(moe_tactic);
+    prepare_routing(clusterSize);
 
     // Execute routing
-    tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
+    tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim,
+                                                                         clusterSize);
     cudaStream_t routing_stream = get_stream(hidden_states.device());
 
     routing_runner.run(
@@ -1670,16 +1705,12 @@ Array<Tensor> trtllm_bf16_moe(Optional<Tensor> const& routing_logits,
   int64_t tile_N = moe_tactic[0];
   int64_t config = moe_tactic[1];
 
-  // Handle default case
-  if (tile_N == -1 || config == -1) {
-    tile_N = *selected_tile_nums.begin();
-  }
-
-  // Get the launcher for the selected tile_N
-  auto& selected_launcher = launchers_map.at(tile_N);
+  // Choose or fall back to default launcher for the given tile_N
+  auto& selected_launcher =
+      get_launcher(launchers_map, selected_tile_nums, tile_N, config, "trtllm_bf16_moe");
 
   // Run the launcher - it will create its own runner internally
-  return selected_launcher->run(config, enable_pdl);
+  return selected_launcher.run(config, enable_pdl);
 }
 
 Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
@@ -1762,16 +1793,12 @@ Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
   int64_t tile_N = config_index[0];
   int64_t config = config_index[1];
 
-  // Handle default case
-  if (tile_N == -1 || config == -1) {
-    tile_N = *selected_tile_nums.begin();
-  }
-
-  // Get the launcher for the selected tile_N
-  auto& selected_launcher = launchers_map.at(tile_N);
+  // Select a launcher
+  auto& selected_launcher = get_launcher(launchers_map, selected_tile_nums, tile_N, config,
+                                         "trtllm_fp8_per_tensor_scale_moe");
 
   // Run the launcher - it will create its own runner internally
-  return selected_launcher->run(config, enable_pdl, use_routing_scales_on_input);
+  return selected_launcher.run(config, enable_pdl, use_routing_scales_on_input);
 }
 
 Array<Tensor> trtllm_fp8_block_scale_moe(
@@ -1883,16 +1910,12 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
   int64_t tile_N = config_index[0];
   int64_t config = config_index[1];
 
-  // Handle default case
-  if (tile_N == -1 || config == -1) {
-    tile_N = *selected_tile_nums.begin();
-  }
-
-  // Get the launcher for the selected tile_N
-  auto& selected_launcher = launchers_map.at(tile_N);
+  // Select a launcher
+  auto& selected_launcher =
+      get_launcher(launchers_map, selected_tile_nums, tile_N, config, "trtllm_fp8_block_scale_moe");
 
   // Run the launcher with DeepSeek FP8 enabled - it will create its own runner internally
-  return selected_launcher->run(
+  return selected_launcher.run(
       config, enable_pdl, false /* use_routing_scales_on_input */,
       quantization_type == Fp8QuantizationType::DeepSeekFp8 /* use_deep_seek_fp8 */);
 }
 
config, "trtllm_fp4_block_scale_moe"); // Run the launcher - it will create its own runner internally - return selected_launcher->run(config, enable_pdl); + return selected_launcher.run(config, enable_pdl); } Array trtllm_mxint4_block_scale_moe( @@ -2116,17 +2134,12 @@ Array trtllm_mxint4_block_scale_moe( int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *selected_tile_nums.begin(); - config = -1; // Let the runner choose default - } - - // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + // Select a launcher + auto& selected_launcher = get_launcher(launchers_map, selected_tile_nums, tile_N, config, + "trtllm_mxint4_block_scale_moe"); // Run the launcher - it will create its own runner internally - return selected_launcher->run(config, enable_pdl); + return selected_launcher.run(config, enable_pdl); } Array> trtllm_get_valid_moe_configs( diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 5408d2d059..36073cd1bd 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -413,6 +413,8 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } else { numCta = divUpTileN(count, params.mTileTokensDim); } + // Expand from CGA count to CTA count to keep the semantic stable with downstream kernels + numCta *= params.mClusterSizeInBatchDim; int32_t ctaOffset; int32_t numNonExitingCtas; @@ -422,30 +424,34 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; + // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize int32_t mnLimit1; int32_t mnLimit2; if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2; + mnLimit1 = mulLog2(ctaOffset + cta + 1, ctaPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, ctaPaddingLog2) + count; } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim; + mnLimit1 = (ctaOffset + cta + 1) * ctaTile; + mnLimit2 = ctaOffset * ctaTile + count; } params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } - // get the padded offset associated with this expert + // get the padded offset associated with this expert (token-space, CGA granularity) int32_t offset; if constexpr (KernelParams::isPow2) { - offset = mulLog2(ctaOffset, params.mPaddingLog2); + offset = mulLog2(ctaOffset >> params.mClusterSizeLog2, params.mPaddingLog2); } else { - offset = mulTileN(ctaOffset, params.mTileTokensDim); + offset = (ctaOffset / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } int32_t permutedIdxSize; if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + permutedIdxSize = + mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2); } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } // write out padded count 
diff --git a/csrc/trtllm_fused_moe_routing_llama4.cu b/csrc/trtllm_fused_moe_routing_llama4.cu index 31674e0a8e..654270654d 100644 --- a/csrc/trtllm_fused_moe_routing_llama4.cu +++ b/csrc/trtllm_fused_moe_routing_llama4.cu @@ -197,6 +197,8 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } numCta += num; } + // Expand from CGA count to CTA count to keep the semantic stable with downstream kernels + numCta *= params.mClusterSizeInBatchDim; // second, we perform the exclusive sum across the warp int32_t ctaOffset; int32_t numNonExitingCtas; @@ -214,19 +216,23 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } else { finalNumCta = divUpTileN(count, params.mTileTokensDim); } + finalNumCta *= params.mClusterSizeInBatchDim; auto expertIdx = threadIdx.x * ExpertsPerThread + ii; // during the scan for expert offsets, we can already write out // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit` for (int cta = 0; cta < finalNumCta; ++cta) { params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx; + // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize int32_t mnLimit1; int32_t mnLimit2; if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffsetExp, params.mPaddingLog2) + count; + int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2; + mnLimit1 = mulLog2(ctaOffsetExp + cta + 1, ctaPaddingLog2); + mnLimit2 = mulLog2(ctaOffsetExp, ctaPaddingLog2) + count; } else { - mnLimit1 = mulTileN(ctaOffsetExp + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffsetExp, params.mTileTokensDim) + count; + int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim; + mnLimit1 = (ctaOffsetExp + cta + 1) * ctaTile; + mnLimit2 = ctaOffsetExp * ctaTile + count; } params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] = min(mnLimit1, mnLimit2); } @@ -237,9 +243,10 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam if (cute::elect_one_sync()) { int32_t permutedIdxSize; if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + permutedIdxSize = + mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2); } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; @@ -259,10 +266,12 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam // of registers auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; int32_t finalExpertOffset[ExpertsPerThread]; + // Convert CTA-level ctaOffset back to token-space (CGA granularity) if constexpr (KernelParams::isPow2) { - finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2); + finalExpertOffset[0] = + mulLog2(ctaOffset >> params.mClusterSizeLog2, params.mPaddingLog2); } else { - finalExpertOffset[0] = mulTileN(ctaOffset, params.mTileTokensDim); + finalExpertOffset[0] = (ctaOffset / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } #pragma unroll for (int ii = 1; ii < ExpertsPerThread; ++ii) { diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu index 364c267c00..584b13ede6 100644 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ 
b/csrc/trtllm_fused_moe_routing_renormalize.cu @@ -183,10 +183,13 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } else { numCta = divUpTileN(accExpertCount, params.mTileTokensDim); } + // Expand from CGA count to CTA count to keep the semantic stable with downstream kernels + numCta *= params.mClusterSizeInBatchDim; int32_t ctaOffset = 0; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); + // Second scan for token-space padded offsets stays at CGA granularity int32_t expertScanCounts = 0; int32_t tmpCount; if constexpr (KernelParams::isPow2) { @@ -202,14 +205,17 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) const int32_t localExpertIdx = (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; + // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize int32_t mnLimit1; int32_t mnLimit2; if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + accExpertCount; + int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2; + mnLimit1 = mulLog2(ctaOffset + cta + 1, ctaPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, ctaPaddingLog2) + accExpertCount; } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + accExpertCount; + int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim; + mnLimit1 = (ctaOffset + cta + 1) * ctaTile; + mnLimit2 = ctaOffset * ctaTile + accExpertCount; } params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } @@ -219,9 +225,10 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) if (threadIdx.x == 0) { int32_t permutedIdxSize; if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + permutedIdxSize = + mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2); } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim; } params.mPtrPermutedIdxSize[0] = permutedIdxSize; params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index af48040d0a..a24b40a2e8 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include #include #include "flashinfer/exception.h" @@ -47,7 +48,8 @@ inline int32_t computeLog2(int32_t val, std::string const& name = "") { Runner::Runner() {} -Runner::Runner(int32_t tileTokensDim) : mTileTokensDim(tileTokensDim) {} +Runner::Runner(int32_t tileTokensDim, int32_t clusterSizeInBatchDim) + : mTileTokensDim(tileTokensDim), mClusterSizeInBatchDim(clusterSizeInBatchDim) {} void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts, int32_t topK, int32_t nGroup, int32_t topkGroup, int32_t localExpertOffset, @@ -93,6 +95,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mTopK = topK; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); routingData.mTileTokensDim = mTileTokensDim; + routingData.mClusterSizeInBatchDim = mClusterSizeInBatchDim; + routingData.mClusterSizeLog2 = computeLog2(mClusterSizeInBatchDim); routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; @@ -129,6 +133,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mTopK = topK; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); routingData.mTileTokensDim = mTileTokensDim; + routingData.mClusterSizeInBatchDim = mClusterSizeInBatchDim; + routingData.mClusterSizeLog2 = computeLog2(mClusterSizeInBatchDim); routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; @@ -177,6 +183,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mTopK = topK; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); routingData.mTileTokensDim = mTileTokensDim; + routingData.mClusterSizeInBatchDim = mClusterSizeInBatchDim; + routingData.mClusterSizeLog2 = computeLog2(mClusterSizeInBatchDim); routingData.mLocalExpertsStartIdx = localExpertOffset; routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; @@ -236,16 +244,14 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( if (isGatedAct) { ActType actType = activationTypeToGatedActType(activationType); tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { - // Swap A and B dtypes because transposeMmaOutput is hardcoded to true - .dtypeA = dtypeWeights, - .dtypeB = dtypeAct, + .dtypeA = dtypeAct, + .dtypeB = dtypeWeights, .dtypeC = dtypeAct, .actType = actType, .deepSeekFp8 = useDeepSeekFp8, .fusedAct = !useDeepSeekFp8, .routeAct = true, .staticBatch = false, - .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = useDeepSeekFp8 ? 
64 : 128, .useShuffledMatrix = useShuffledMatrix, @@ -254,16 +260,14 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( } else { EltwiseActType actType = activationTypeToEltwiseActType(activationType); tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { - // Swap A and B dtypes because transposeMmaOutput is hardcoded to true - .dtypeA = dtypeWeights, - .dtypeB = dtypeAct, + .dtypeA = dtypeAct, + .dtypeB = dtypeWeights, .dtypeC = dtypeAct, .eltwiseActType = actType, .deepSeekFp8 = useDeepSeekFp8, .fusedAct = false, .routeAct = true, .staticBatch = false, - .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = 128, .useShuffledMatrix = useShuffledMatrix, @@ -292,11 +296,11 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput, int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl) { - auto maxNumCtasInBatchDim = - Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = + Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, - numExperts, maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, + numExperts, maxNumCgasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, @@ -307,35 +311,35 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t configIndex) const { - auto maxNumCtasInBatchDim = - Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = + Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); return mRunner.getWorkspaceSizeInBytes(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, configIndex); + maxNumCgasInBatchDim, configIndex); } int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens) const { - auto maxNumCtasInBatchDim = - Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = + Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 
@@ -307,35 +311,35 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void*

 size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
                                        int32_t numExperts, int32_t numTokens,
                                        int32_t configIndex) const {
-  auto maxNumCtasInBatchDim =
-      Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+  auto maxNumCgasInBatchDim =
+      Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
   int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1);
   return mRunner.getWorkspaceSizeInBytes(numTokens, intermediateSizeFactor * intermediateSize,
                                          hiddenSize, {}, numTokens, numExperts,
-                                         maxNumCtasInBatchDim, configIndex);
+                                         maxNumCgasInBatchDim, configIndex);
 }

 int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize,
                                            int32_t intermediateSize, int32_t numExperts,
                                            int32_t numTokens) const {
-  auto maxNumCtasInBatchDim =
-      Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+  auto maxNumCgasInBatchDim =
+      Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
   int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1);
   return mRunner.getDefaultValidConfigIndex(numTokens, intermediateSizeFactor * intermediateSize,
                                             hiddenSize, {}, numTokens, numExperts,
-                                            maxNumCtasInBatchDim);
+                                            maxNumCgasInBatchDim);
 }

 bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize,
                                 int32_t intermediateSize, int32_t numExperts,
                                 int32_t numTokens) const {
-  auto maxNumCtasInBatchDim =
-      Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+  auto maxNumCgasInBatchDim =
+      Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
   int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1);
   auto const isValid =
       mRunner.isValidConfigIndex(configIndex, numTokens, intermediateSizeFactor * intermediateSize,
-                                 hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim);
+                                 hiddenSize, {}, numTokens, numExperts, maxNumCgasInBatchDim);
   return isValid;
 }

@@ -343,6 +347,10 @@ bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hidde
 std::vector<int32_t> Runner::getPassingConfigIndices() const {
   return mRunner.getPassingConfigIndices();
 }
+
+int32_t Runner::getConfigClusterSizeInBatchDim(int32_t configIndex) const {
+  return mRunner.getConfigClusterSizeInBatchDim(configIndex);
+}
 }  // namespace PermuteGemm1

 namespace Gemm2 {
@@ -350,16 +358,14 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(
     btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, int32_t tileTokensDim,
     bool useDeepSeekFp8, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) {
   tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = {
-      // Swap A and B dtypes because transposeMmaOutput is hardcoded to true
-      .dtypeA = dtypeWeights,
-      .dtypeB = dtypeAct,
+      .dtypeA = dtypeAct,
+      .dtypeB = dtypeWeights,
      .dtypeC = dtypeOut,
      .eltwiseActType = EltwiseActType::None,
      .deepSeekFp8 = useDeepSeekFp8,
      .fusedAct = false,
      .routeAct = false,
      .staticBatch = false,
-      .transposeMmaOutput = true,
      .tileSize = tileTokensDim,
      .epilogueTileM = useDeepSeekFp8 ? 64 : 128,
      .useShuffledMatrix = useShuffledMatrix,
@@ -385,10 +391,10 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void*
                  int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx,
                  int32_t* ptrCtaIdxXyToMnLimit, void* bmm2Workspace, int device,
                  cudaStream_t stream, int32_t configIndex, bool enable_pdl) {
-  auto maxNumCtasInBatchDim =
-      Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+  auto maxNumCgasInBatchDim =
+      Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
   mRunner.run(
-      numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim,
+      numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCgasInBatchDim,
       permutedHiddenState, permutedHiddenStateScale, weights, weightsScale,
       /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, outputScalesScalar,
       /* outputScalesGateScalar */ nullptr, ptrBias,
@@ -401,30 +407,30 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void*

 size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
                                        int32_t numExperts, int32_t numTokens,
                                        int32_t configIndex) const {
-  auto maxNumCtasInBatchDim =
-      Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+  auto maxNumCgasInBatchDim =
+      Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
   return mRunner.getWorkspaceSizeInBytes(numTokens, hiddenSize, intermediateSize, {}, numTokens,
-                                         numExperts, maxNumCtasInBatchDim, configIndex);
+                                         numExperts, maxNumCgasInBatchDim, configIndex);
 }

 int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize,
                                            int32_t intermediateSize, int32_t numExperts,
                                            int32_t numTokens) const {
-  auto maxNumCtasInBatchDim =
-      Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+  auto maxNumCgasInBatchDim =
+      Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
   return mRunner.getDefaultValidConfigIndex(numTokens, hiddenSize, intermediateSize, {}, numTokens,
-                                            numExperts, maxNumCtasInBatchDim);
+                                            numExperts, maxNumCgasInBatchDim);
 }

 bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize,
                                 int32_t intermediateSize, int32_t numExperts,
                                 int32_t numTokens) const {
-  auto const maxNumCtasInBatchDim =
-      Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
+  auto const maxNumCgasInBatchDim =
+      Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);
   auto const isValid =
       mRunner.isValidConfigIndex(configIndex, numTokens, hiddenSize, intermediateSize, {},
-                                 numTokens, numExperts, maxNumCtasInBatchDim);
+                                 numTokens, numExperts, maxNumCgasInBatchDim);
   return isValid;
 }

@@ -432,6 +438,10 @@ bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hidde
 std::vector<int32_t> Runner::getPassingConfigIndices() const {
   return mRunner.getPassingConfigIndices();
 }
+
+int32_t Runner::getConfigClusterSizeInBatchDim(int32_t configIndex) const {
+  return mRunner.getConfigClusterSizeInBatchDim(configIndex);
+}
 }  // namespace Gemm2

 namespace MoE {
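[Annotation — not part of the patch] Both GEMM namespaces now expose getConfigClusterSizeInBatchDim, so MoE-level code can query the batch-dimension cluster size of each candidate config before pairing them. A hedged sketch of that check (the constructor hunk below does exactly this while building mPassingConfigs):

    int32_t const c1 = mPermuteGemm1.getConfigClusterSizeInBatchDim(indexGemm1);
    int32_t const c2 = mGemm2.getConfigClusterSizeInBatchDim(indexGemm2);
    bool const pairIsUsable = (c1 == c2);  // mismatched pairs are skipped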
@@ -449,7 +459,15 @@ Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8
   mPassingConfigs.reserve(totalPassingIndices);
   for (auto const& indexGemm1 : gemm1PassingIndices) {
+    int32_t const gemm1ClusterSize = mPermuteGemm1.getConfigClusterSizeInBatchDim(indexGemm1);
     for (auto const& indexGemm2 : gemm2PassingIndices) {
+      int32_t const gemm2ClusterSize = mGemm2.getConfigClusterSizeInBatchDim(indexGemm2);
+      // Routing emits one shared CTA table for both GEMMs, so FC1 and FC2 must agree on the
+      // batch-dimension cluster size and tile size; otherwise ctaIdxXyToMnLimit/numNonExitingCtas
+      // are generated at one CTA granularity while one GEMM still consumes a different batch tile
+      // size.
+      if (gemm1ClusterSize != gemm2ClusterSize) {
+        continue;
+      }
       mPassingConfigs.push_back(MoEConfig{indexGemm1, indexGemm2});
     }
   }
@@ -571,6 +589,19 @@ int64_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize,
   return std::distance(mPassingConfigs.begin(), it);
 }

+int32_t Runner::getConfigClusterSizeInBatchDim(int64_t configIndex) const {
+  FLASHINFER_CHECK(configIndex >= 0 && configIndex < static_cast<int64_t>(mPassingConfigs.size()),
+                   "Invalid MoE config index ", configIndex, ". Valid range is [0, ",
+                   static_cast<int64_t>(mPassingConfigs.size()) - 1, "].");
+  auto const& config = mPassingConfigs[configIndex];
+  int32_t const gemm1ClusterSize = mPermuteGemm1.getConfigClusterSizeInBatchDim(config.gemm1Config);
+  int32_t const gemm2ClusterSize = mGemm2.getConfigClusterSizeInBatchDim(config.gemm2Config);
+  FLASHINFER_CHECK(gemm1ClusterSize == gemm2ClusterSize,
+                   "Incompatible MoE config pair: gemm1 clusterSizeInBatchDim=", gemm1ClusterSize,
+                   ", gemm2 clusterSizeInBatchDim=", gemm2ClusterSize, ".");
+  return gemm1ClusterSize;
+}
+
 void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device,
                  cudaStream_t stream, int64_t configIndex, bool enable_pdl) {
   FLASHINFER_CHECK(configIndex >= 0 && configIndex < static_cast<int64_t>(mPassingConfigs.size()),
diff --git a/include/flashinfer/exception.h b/include/flashinfer/exception.h
index aaaa2b5b3e..62d8422062 100644
--- a/include/flashinfer/exception.h
+++ b/include/flashinfer/exception.h
@@ -70,6 +70,17 @@ void write_to_stream(std::ostringstream& oss, T&& val, Args&&... args) {
     flashinfer::Warning(__FUNCTION__, __FILE__, __LINE__, msg).emit(); \
   } while (0)

+#define FLASHINFER_LOG(...)                                            \
+  do {                                                                 \
+    std::ostringstream oss;                                            \
+    write_to_stream(oss, ##__VA_ARGS__);                               \
+    std::string msg = oss.str();                                       \
+    if (msg.empty()) {                                                 \
+      msg = "Log triggered";                                           \
+    }                                                                  \
+    flashinfer::Log(__FUNCTION__, __FILE__, __LINE__, msg).emit();     \
+  } while (0)
+
 namespace flashinfer {
 class Error : public std::exception {
  private:
@@ -101,6 +112,21 @@ class Warning {
   void emit() const { std::cerr << message_ << std::endl; }
 };

+class Log {
+ private:
+  std::string message_;
+
+ public:
+  Log(const std::string& func, const std::string& file, int line, const std::string& message) {
+    std::ostringstream oss;
+    oss << "Log in function '" << func << "' "
+        << "at " << file << ":" << line << ": " << message;
+    message_ = oss.str();
+  }
+
+  void emit() const { std::cerr << message_ << std::endl; }
+};
+
 }  // namespace flashinfer

 #endif  // FLASHINFER_EXCEPTION_H_
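[Annotation — not part of the patch] A usage sketch of the new FLASHINFER_LOG macro, with hypothetical variables; like FLASHINFER_WARN, the arguments are streamed into the message via write_to_stream:

    FLASHINFER_LOG("gemm1 clusterSizeInBatchDim=", gemm1ClusterSize,
                   ", gemm2 clusterSizeInBatchDim=", gemm2ClusterSize);
    FLASHINFER_LOG();  // empty argument list falls back to "Log triggered"
    // Both print to stderr as: Log in function '<func>' at <file>:<line>: <message>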
diff --git a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h
index 64e958d9e8..f9d9929437 100644
--- a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h
+++ b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h
@@ -63,8 +63,11 @@ enum class EltwiseActType {
 };

 struct TrtllmGenBatchedGemmRunnerOptions {
+  // Canonically, A is activation.
   batchedGemm::trtllm::gen::Dtype dtypeA;
+  // B is weight.
   batchedGemm::trtllm::gen::Dtype dtypeB;
+  // C is output.
   batchedGemm::trtllm::gen::Dtype dtypeC;
   ActType actType{ActType::SwiGlu};
   EltwiseActType eltwiseActType{EltwiseActType::None};
@@ -72,6 +75,7 @@ struct TrtllmGenBatchedGemmRunnerOptions {
   bool fusedAct{false};
   bool routeAct{false};
   bool staticBatch{false};
+  // If transposeMmaOutput is true, then A and B are swapped under the hood.
   bool transposeMmaOutput{false};
   int32_t tileSize{8};
   int32_t epilogueTileM{128};
@@ -121,6 +125,9 @@ class TrtllmGenBatchedGemmRunner {
     return mPassingConfigIndices;
   }

+  // Get the cluster size in the batch dimension for the given config index
+  [[nodiscard]] int32_t getConfigClusterSizeInBatchDim(int32_t configIndex) const;
+
   // Get the list of config indices that are valid for the given problem shape
   [[nodiscard]] std::vector<int32_t> getValidConfigIndices(
       int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
index 17143ab8a4..01e682c963 100644
--- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
+++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
@@ -324,6 +324,7 @@ __device__ void routingPermutation(KernelParams params,
     } else {
       numCta = divUpTileN(count, params.mTileTokensDim);
     }
+    numCta *= params.mClusterSizeInBatchDim;

     int32_t ctaOffset;
     int32_t numNonExitingCtas;
@@ -335,24 +336,27 @@ __device__ void routingPermutation(KernelParams params,
       const int32_t localExpertIdx =
           (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
       params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
+      // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize
       int32_t mnLimit1;
       int32_t mnLimit2;
       if constexpr (KernelParams::isPow2) {
-        mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2);
-        mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count;
+        int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2;
+        mnLimit1 = mulLog2(ctaOffset + cta + 1, ctaPaddingLog2);
+        mnLimit2 = mulLog2(ctaOffset, ctaPaddingLog2) + count;
       } else {
-        mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim);
-        mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count;
+        int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim;
+        mnLimit1 = (ctaOffset + cta + 1) * ctaTile;
+        mnLimit2 = ctaOffset * ctaTile + count;
       }
       params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2);
     }
-    // get the padded offset associated with this expert
+    // get the padded offset associated with this expert (token-space, CGA granularity)
     int32_t offset;
     if constexpr (KernelParams::isPow2) {
-      offset = mulLog2(ctaOffset, params.mPaddingLog2);
+      offset = mulLog2(ctaOffset >> params.mClusterSizeLog2, params.mPaddingLog2);
     } else {
-      offset = mulTileN(ctaOffset, params.mTileTokensDim);
+      offset = (ctaOffset / params.mClusterSizeInBatchDim) * params.mTileTokensDim;
     }
     // write expert offsets to shared
     smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;
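[Annotation — not part of the patch] A standalone worked example of the CGA-aware CTA accounting above (non-pow2 path), using hypothetical sizes:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // Hypothetical shapes: a 16-token CGA tile, a 2-CTA cluster in the batch
    // dimension, and one expert that received 20 routed tokens.
    int main() {
      int32_t const tileTokensDim = 16;                     // params.mTileTokensDim (CGA tile)
      int32_t const clusterSize = 2;                        // params.mClusterSizeInBatchDim
      int32_t const ctaTile = tileTokensDim / clusterSize;  // 8 tokens per CTA
      int32_t const count = 20;                             // tokens routed to this expert
      int32_t const ctaOffset = 0;                          // exclusive prefix sum (CTA units)

      int32_t numCta = (count + tileTokensDim - 1) / tileTokensDim;  // 2 CGA tiles
      numCta *= clusterSize;                                         // -> 4 CTAs

      // Per-CTA MN limits, clamped to the expert's real token count.
      int32_t mnLimits[4];
      for (int32_t cta = 0; cta < numCta; ++cta) {
        mnLimits[cta] = std::min((ctaOffset + cta + 1) * ctaTile, ctaOffset * ctaTile + count);
      }
      assert(mnLimits[0] == 8 && mnLimits[1] == 16 && mnLimits[2] == 20 && mnLimits[3] == 20);

      // The expert's token-space start offset stays at CGA granularity.
      assert((ctaOffset / clusterSize) * tileTokensDim == 0);
      return 0;
    }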
@@ -362,9 +366,10 @@ __device__ void routingPermutation(KernelParams params,
   if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) {
     int32_t permutedIdxSize;
     if constexpr (KernelParams::isPow2) {
-      permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2);
+      permutedIdxSize =
+          mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2);
     } else {
-      permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim);
+      permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim;
     }
     params.mPtrPermutedIdxSize[0] = permutedIdxSize;
     params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
@@ -557,24 +562,24 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
     // Compute the runtime config for projections
     // Whether or not an expert is local is taken into account when the histogram is computed
     // so we do not need to take it into account here.
-    // const int32_t numCta = divUpLog2(count, params.mPaddingLog2);
     int32_t numCta;
     if constexpr (KernelParams::isPow2) {
       numCta = divUpLog2(count, params.mPaddingLog2);
     } else {
       numCta = divUpTileN(count, params.mTileTokensDim);
     }
+    numCta *= params.mClusterSizeInBatchDim;

     int32_t ctaOffset;
     int32_t numNonExitingCtas;
     Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

     if (threadIdx.x < params.mNumExperts) {
-      // Get the padded offset associated with this expert
+      // Get the padded offset associated with this expert (token-space, CGA granularity)
       int32_t offset;
       if constexpr (KernelParams::isPow2) {
-        offset = mulLog2(ctaOffset, params.mPaddingLog2);
+        offset = mulLog2(ctaOffset >> params.mClusterSizeLog2, params.mPaddingLog2);
       } else {
-        offset = mulTileN(ctaOffset, params.mTileTokensDim);
+        offset = (ctaOffset / params.mClusterSizeInBatchDim) * params.mTileTokensDim;
       }

       // Write expert offsets to shared
@@ -589,9 +594,10 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
       cute::elect_one_sync()) {
     int32_t permutedIdxSize;
     if constexpr (KernelParams::isPow2) {
-      permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2);
+      permutedIdxSize =
+          mulLog2(numNonExitingCtas >> params.mClusterSizeLog2, params.mPaddingLog2);
     } else {
-      permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim);
+      permutedIdxSize = (numNonExitingCtas / params.mClusterSizeInBatchDim) * params.mTileTokensDim;
     }
     params.mPtrPermutedIdxSize[0] = permutedIdxSize;
     params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
@@ -603,14 +609,17 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
       const int32_t localExpertIdx =
           (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
       params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
+      // Write CTA-level MnLimits using ctaTile = cgaTile / clusterSize
      int32_t mnLimit1;
      int32_t mnLimit2;
      if constexpr (KernelParams::isPow2) {
-        mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2);
-        mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count;
+        int32_t ctaPaddingLog2 = params.mPaddingLog2 - params.mClusterSizeLog2;
+        mnLimit1 = mulLog2(ctaOffset + cta + 1, ctaPaddingLog2);
+        mnLimit2 = mulLog2(ctaOffset, ctaPaddingLog2) + count;
      } else {
-        mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim);
-        mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count;
+        int32_t ctaTile = params.mTileTokensDim / params.mClusterSizeInBatchDim;
+        mnLimit1 = (ctaOffset + cta + 1) * ctaTile;
+        mnLimit2 = ctaOffset * ctaTile + count;
      }
      params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2);
     }
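[Annotation — not part of the patch] A standalone sketch of the CGA-granularity padding math above: numNonExitingCtas counts CTAs, so the permuted-index size divides the cluster size back out before multiplying by the cluster-wide tile. Hypothetical sizes:

    #include <cassert>
    #include <cstdint>

    // Hypothetical: two local experts with 20 and 5 routed tokens, a 16-token
    // CGA tile, and a 2-CTA cluster in the batch dimension.
    int main() {
      int32_t const tileTokensDim = 16;
      int32_t const clusterSize = 2;

      auto ctasFor = [&](int32_t count) {
        int32_t numCga = (count + tileTokensDim - 1) / tileTokensDim;
        return numCga * clusterSize;  // numCta *= mClusterSizeInBatchDim
      };
      int32_t numNonExitingCtas = ctasFor(20) + ctasFor(5);  // 4 + 2 = 6 CTAs

      // The permuted index buffer is sized in CGA tiles, not CTAs.
      int32_t permutedIdxSize = (numNonExitingCtas / clusterSize) * tileTokensDim;
      assert(permutedIdxSize == 48);  // expert 0 occupies [0, 32), expert 1 [32, 48)
      return 0;
    }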
diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h
index ba90742ce0..8db44c6930 100644
--- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h
+++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h
@@ -95,8 +95,14 @@ struct DataBase {
   int32_t mNumTokens;
   int32_t mNumExperts;
   int32_t mTopK;
-  int32_t mPaddingLog2;
+  // Cluster-wide tile size in the token dimension.
   int32_t mTileTokensDim;
+  // log2() of the cluster-wide tile size used for padding.
+  int32_t mPaddingLog2;
+  // Cluster size (e.g., 1x2, 2x1, etc.) in the batch dimension.
+  int32_t mClusterSizeInBatchDim{1};
+  // log2() of the cluster size in the batch dimension.
+  int32_t mClusterSizeLog2{0};

   /// For expert parallelization
   int32_t mLocalExpertsStartIdx;
@@ -131,6 +137,8 @@ struct KernelParamsBase {

   int32_t mPaddingLog2 = -1;
   int32_t mTileTokensDim = 0;
+  int32_t mClusterSizeInBatchDim = 1;
+  int32_t mClusterSizeLog2 = 0;
   int32_t mLocalExpertsStartIdx = 0;
   int32_t mLocalExpertsStrideLog2 = 0;
   int32_t mNumLocalExperts = 0;
@@ -155,6 +163,8 @@ struct KernelParamsBase {

     mPaddingLog2 = data.mPaddingLog2;
     mTileTokensDim = data.mTileTokensDim;
+    mClusterSizeInBatchDim = data.mClusterSizeInBatchDim;
+    mClusterSizeLog2 = data.mClusterSizeLog2;
     mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
     mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
     mNumLocalExperts = data.mNumLocalExperts;
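[Annotation — not part of the patch] A small sketch of the invariants the new DataBase fields are expected to satisfy, with hypothetical values: mPaddingLog2 stays at CGA-tile granularity, and the kernels derive per-CTA quantities by subtracting mClusterSizeLog2.

    #include <cassert>
    #include <cstdint>

    int main() {
      int32_t mTileTokensDim = 16;         // cluster-wide (CGA) tile in the token dim
      int32_t mPaddingLog2 = 4;            // log2(mTileTokensDim)
      int32_t mClusterSizeInBatchDim = 2;  // CTAs per cluster in the batch dim
      int32_t mClusterSizeLog2 = 1;        // log2(mClusterSizeInBatchDim)

      // The kernels derive the per-CTA tile by removing the cluster factor.
      assert((mTileTokensDim >> mClusterSizeLog2) == mTileTokensDim / mClusterSizeInBatchDim);
      assert((mPaddingLog2 - mClusterSizeLog2) == 3);  // log2 of the 8-token CTA tile
      return 0;
    }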
diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h
index 46617e5dbd..c7f3fe250f 100644
--- a/include/flashinfer/trtllm/fused_moe/runner.h
+++ b/include/flashinfer/trtllm/fused_moe/runner.h
@@ -78,18 +78,18 @@ inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethod
   };
 }

-inline int32_t getMaxNumCtasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts,
-                                       int32_t tileTokensDim) {
-  // For MoE, mNumTokens != 0 and the number of CTAs is known only at runtime.
-  // We launch maximally possible number of CTAs and use ptrNumNonExitingCtas to determine
-  // the actual number of CTAs to run.
+inline int32_t getMaxNumCgasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts,
+                                       int32_t cgaTileTokensDim) {
+  // For MoE, mNumTokens != 0 and the number of CGAs is known only at runtime.
+  // We launch the maximum possible number of CGAs and use ptrNumNonExitingCtas to determine
+  // the actual number of CGAs to run.
   // Initialize number of tokens with the number of expanded tokens after routing.
-  int32_t numRemainingTokens = numTokens * topK;
-  int32_t maxNumCtasInBatchDim = 0;
+  auto numRemainingTokens = numTokens * topK;
+  int32_t maxNumCgasInBatchDim = 0;
   // First, distribute one token each expert until token depletion to maximize CTA tile count.
-  int32_t numExpertsFilled = std::min(numExperts, numRemainingTokens);
-  maxNumCtasInBatchDim += numExpertsFilled;
+  auto numExpertsFilled = std::min(numExperts, numRemainingTokens);
+  maxNumCgasInBatchDim += numExpertsFilled;
   numRemainingTokens -= numExpertsFilled;
   // Next, greedily pour all remaining tokens to one expert to maximize CTA tile count.
   // E.g., at this point tokens over 4 experts are [1, 1, 1, 1], and we have 4 tokens left.
@@ -100,24 +100,29 @@ inline int32_t getMaxNumCtasInBatchDim(int32_t numTokens, int32_t topK, int32_t
   // capacity. These buckets, if full, can then be attributed to any expert; it does not have to
   // belong to the same expert every time.
   if (numRemainingTokens > 0) {
-    // For every tileTokenDim tokens, we add an extra CTA tile in the token dimension.
-    // The number of CTA tiles is given by divDown(numRemainingTokens, tokenTileDim).
-    maxNumCtasInBatchDim += (numRemainingTokens / tileTokensDim);
+    // For every cgaTileTokensDim tokens, we add an extra CGA tile in the token dimension.
+    // The number of CGA tiles is given by divDown(numRemainingTokens, cgaTileTokensDim).
+    maxNumCgasInBatchDim += (numRemainingTokens / cgaTileTokensDim);
   }
-  return maxNumCtasInBatchDim;
+  return maxNumCgasInBatchDim;
+}
+
+inline int32_t getCgaSizeInBatchDim(bool transposeMmaOutput, int32_t clusterDimX,
+                                    int32_t clusterDimY) {
+  return transposeMmaOutput ? clusterDimY : clusterDimX;
 }

 inline int32_t getMaxPermutedPaddedCount(int32_t numTokens, int32_t expertsPerToken,
                                          int32_t numExperts, int32_t padding) {
-  int32_t maxCtas = getMaxNumCtasInBatchDim(numTokens, expertsPerToken, numExperts, padding);
-  return maxCtas * padding;
+  int32_t maxCgas = getMaxNumCgasInBatchDim(numTokens, expertsPerToken, numExperts, padding);
+  return maxCgas * padding;
 }

 class Runner {
  public:
   explicit Runner();
-  explicit Runner(int32_t tileTokensDim);
+  explicit Runner(int32_t tileTokensDim, int32_t clusterSizeInBatchDim = 1);

   void run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts,
            int32_t topK, int32_t nGroups, int32_t topkGroups, int32_t localExpertOffset,
@@ -132,6 +137,7 @@ class Runner {

  private:
   int32_t mTileTokensDim{8};
+  int32_t mClusterSizeInBatchDim{1};
 };

 }  // namespace Routing
@@ -201,6 +207,8 @@ class Runner {

   [[nodiscard]] std::vector<int32_t> getPassingConfigIndices() const;

+  [[nodiscard]] int32_t getConfigClusterSizeInBatchDim(int32_t configIndex) const;
+
   void run(void* hiddenState, void* hiddenStateScale, void* weight, void* weightScale,
            void* expertWeights, float* outputScalesScalar, float* outputScalesGateScalar,
            float* ptrBias, float* ptrGatedActAlpha, float* ptrGatedActBeta, float* ptrClampLimit,
@@ -242,6 +250,8 @@ class Runner {

   [[nodiscard]] std::vector<int32_t> getPassingConfigIndices() const;

+  [[nodiscard]] int32_t getConfigClusterSizeInBatchDim(int32_t configIndex) const;
+
   void run(void* permutedHiddenState, void* permutedHiddenStateScale, void* weight,
            void* weightScale, float* outputScalesScalar, float* ptrBias, void* output,
            void* outputScale, int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
@@ -348,7 +358,6 @@ struct MoEWorkspace {
   float* permuted_hidden_states_scale = nullptr;

   // Gemm1 intermediate outputs:
-  int32_t ProjUpTileN{0};
   void* gemm1_output = nullptr;
   float* gemm1_output_scale = nullptr;

@@ -404,6 +413,8 @@ class Runner {
                                       int32_t numLocalExperts, int32_t numTokens) const;

+  [[nodiscard]] int32_t getConfigClusterSizeInBatchDim(int64_t configIndex) const;
+
  private:
   void setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace,
                   moe::dev::convertsf::Data& convertSfData,
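[Annotation — not part of the patch] A standalone re-derivation of the getMaxNumCgasInBatchDim bound above, evaluated for concrete (hypothetical) shapes:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // Mirrors the inline helper: one CGA per token-filled expert, plus one CGA
    // per full cgaTileTokensDim bucket of leftover expanded tokens.
    int32_t maxCgas(int32_t numTokens, int32_t topK, int32_t numExperts, int32_t cgaTileTokensDim) {
      int32_t numRemainingTokens = numTokens * topK;  // expanded tokens after routing
      int32_t numExpertsFilled = std::min(numExperts, numRemainingTokens);
      int32_t result = numExpertsFilled;
      numRemainingTokens -= numExpertsFilled;
      return result + numRemainingTokens / cgaTileTokensDim;
    }

    int main() {
      // 100 tokens, top-8 routing, 32 experts, 8-token CGA tile:
      // 800 expanded tokens -> 32 one-token experts + 768 / 8 = 96 -> 128 CGAs.
      assert(maxCgas(100, 8, 32, 8) == 128);
      // With clusterSizeInBatchDim = 2, routing then launches at most 2 * 128 CTAs.
      return 0;
    }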