From bc4ed4e8ae4d008c1857598c4cfd39246a1d3b4c Mon Sep 17 00:00:00 2001
From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
Date: Fri, 14 Nov 2025 03:36:58 -0800
Subject: [PATCH 1/4] wip

---
 csrc/trtllm_fused_moe_dev_kernel.cu         | 205 +++++++++++++++++-
 csrc/trtllm_fused_moe_runner.cu             |  34 +--
 flashinfer/jit/fused_moe.py                 |   1 +
 .../flashinfer/trtllm/fused_moe/DevKernel.h |  28 ++-
 4 files changed, 240 insertions(+), 28 deletions(-)

diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu
index 9a51384090..0fe088a174 100644
--- a/csrc/trtllm_fused_moe_dev_kernel.cu
+++ b/csrc/trtllm_fused_moe_dev_kernel.cu
@@ -672,11 +672,209 @@ __device__ float4 vectorizedLoadPtx(float4 const* ptr) {
 // Final kernel to unpermute and scale
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
 // connection.
+constexpr int MaxTopK = 64;
+
+typedef struct __CUDA_ALIGN__(4) {
+  cutlass::bfloat16_t array[2];
+} __nv_bfloat16_2;
+
+typedef struct __CUDA_ALIGN__(8) {
+  cutlass::bfloat16_t array[4];
+} __nv_bfloat16_4;
+
+typedef struct __CUDA_ALIGN__(8) {
+  half array[4];
+} half_4;
+
+template<int VecSize_, typename TypeExpW_>
+struct ScaleTraitsStruct;
+
+template<>
+struct ScaleTraitsStruct<1, cutlass::bfloat16_t> {
+  using PackedType = cutlass::bfloat16_t;
+  using ArrayType = cutlass::Array<__nv_bfloat16, 1>;
+};
+
+template<>
+struct ScaleTraitsStruct<2, cutlass::bfloat16_t> {
+  using PackedType = __nv_bfloat16_2;
+  using ArrayType = cutlass::Array<__nv_bfloat16, 2>;
+};
+
+template<>
+struct ScaleTraitsStruct<4, cutlass::bfloat16_t> {
+  using PackedType = __nv_bfloat16_4;
+  using ArrayType = cutlass::Array<cutlass::bfloat16_t, 4>;
+};
+
+template<>
+struct ScaleTraitsStruct<1, float> {
+  using PackedType = float;
+  using ArrayType = cutlass::Array<float, 1>;
+};
+
+template<>
+struct ScaleTraitsStruct<2, float> {
+  using PackedType = float2;
+  using ArrayType = cutlass::Array<float, 2>;
+};
+
+template<>
+struct ScaleTraitsStruct<4, float> {
+  using PackedType = float4;
+  using ArrayType = cutlass::Array<float, 4>;
+};
+
+template<>
+struct ScaleTraitsStruct<1, half> {
+  using PackedType = half;
+  using ArrayType = cutlass::Array<half, 1>;
+};
+
+template<>
+struct ScaleTraitsStruct<2, half> {
+  using PackedType = half2;
+  using ArrayType = cutlass::Array<half, 2>;
+};
+
+template<>
+struct ScaleTraitsStruct<4, half> {
+  using PackedType = half_4;
+  using ArrayType = cutlass::Array<half, 4>;
+};
+
+template<int VecSize_, typename TypeExpW_>
+struct FinalizeTraits;
+
+template<typename TypeExpW_>
+struct FinalizeTraits<1, TypeExpW_> {
+  using IdxPackedType = int;
+  using IdxArrayType = cutlass::Array<int, 1>;
+  using ScaleTraits = ScaleTraitsStruct<1, TypeExpW_>;
+  using ScalePackedType = typename ScaleTraits::PackedType;
+  using ScaleArrayType = typename ScaleTraits::ArrayType;
+};
+
+template<typename TypeExpW_>
+struct FinalizeTraits<2, TypeExpW_> {
+  using IdxPackedType = int2;
+  using IdxArrayType = cutlass::Array<int, 2>;
+  using ScaleTraits = ScaleTraitsStruct<2, TypeExpW_>;
+  using ScalePackedType = typename ScaleTraits::PackedType;
+  using ScaleArrayType = typename ScaleTraits::ArrayType;
+};
+
+template<typename TypeExpW_>
+struct FinalizeTraits<4, TypeExpW_> {
+  using IdxPackedType = int4;
+  using IdxArrayType = cutlass::Array<int, 4>;
+  using ScaleTraits = ScaleTraitsStruct<4, TypeExpW_>;
+  using ScalePackedType = typename ScaleTraits::PackedType;
+  using ScaleArrayType = typename ScaleTraits::ArrayType;
+};
 
 template <typename KernelParams>
 __global__ void finalizeKernelVecLoad(KernelParams params) {
   using Type = typename KernelParams::Type;
   using TypeExpW = typename KernelParams::TypeExpW;
+  int constexpr TopKUnrollFactor = KernelParams::TopKUnrollFactor;
+
+  static_assert(TopKUnrollFactor == 1 || TopKUnrollFactor == 2 || TopKUnrollFactor == 4, "TopKUnrollFactor must be 1, 2, or 4");
+  using FinalizeTraits = FinalizeTraits<TopKUnrollFactor, TypeExpW>;
+  using IdxPackedType = typename FinalizeTraits::IdxPackedType;
+  using IdxArrayType = typename FinalizeTraits::IdxArrayType;
+  using ScalePackedType = typename FinalizeTraits::ScalePackedType;
+  using ScaleArrayType = typename FinalizeTraits::ScaleArrayType;
+
+  int const hiddenDimPaddedBits = params.hiddenDimPadded * cutlass::sizeof_bits<Type>::value;
+  int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits<Type>::value;
+  assert(hiddenDimPaddedBits % 128 == 0);
+  assert(hiddenDimBits % 128 == 0);
+
+  // Load 128-bits per thread, according to the smallest data type we read/write
+  constexpr int64_t FINALIZE_ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<Type>::value;
+  using InputElem = cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD>;
+  using OutputElem = cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD>;
+  using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
+
+  int64_t const tokenIdx = blockIdx.x;
+  int64_t const startOffset = threadIdx.x;
+  int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
+  int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD;
+  int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD;
+
+  __shared__ __nv_bfloat16_4 scaleArrSmem[MaxTopK / TopKUnrollFactor];
+  __shared__ IdxPackedType permutedIdxArrSmem[MaxTopK / TopKUnrollFactor];
+
+  for (int kChunkIdx = threadIdx.x; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx += blockDim.x) {
+    int const expandedIdx = tokenIdx * params.topK + kChunkIdx * TopKUnrollFactor;
+    auto permutedIdxPacked = reinterpret_cast<IdxPackedType const*>(params.expandedIdxToPermutedIdx)[expandedIdx/4];
+    // auto permutedIdxArr = *reinterpret_cast<cutlass::Array<int, TopKUnrollFactor> const*>(&permutedIdxPacked);
+    auto scalePacked = (params.expertWeightsPtr != nullptr)
+        ? reinterpret_cast<__nv_bfloat16_4 const*>(params.expertWeightsPtr)[expandedIdx/4]
+        : __nv_bfloat16_4{cutlass::bfloat16_t(1.f), cutlass::bfloat16_t(1.f), cutlass::bfloat16_t(1.f), cutlass::bfloat16_t(1.f)};
+    // auto scaleArr = *reinterpret_cast<ScaleArrayType const*>(&scalePacked);
+
+    scaleArrSmem[kChunkIdx] = scalePacked;
+    permutedIdxArrSmem[kChunkIdx] = permutedIdxPacked;
+  }
+  
+  auto const offset = tokenIdx * params.hiddenDim;
+  Type* outputPtr = params.outPtr + offset;
+  auto* outElemPtr = reinterpret_cast<OutputElem*>(outputPtr);
+  auto const* inElemPtr = reinterpret_cast<InputElem const*>(params.inPtr);
+  
+  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+  // wait on primary kernel when using PDL
+  if constexpr (KernelParams::UsePdl) {
+    cudaGridDependencySynchronize();
+  }
+  #endif
+  __syncthreads();
+
+
+  for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride) {
+    ComputeElem threadOutput;
+    threadOutput.fill(0);
+    for (int kChunkIdx = 0; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx++) {
+      auto permutedIdxArr = *reinterpret_cast<cutlass::Array<int, TopKUnrollFactor> const*>(&permutedIdxArrSmem[kChunkIdx]);
+      InputElem inputElemArr[TopKUnrollFactor];
+      #pragma unroll
+      for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
+        auto const permutedIdx = permutedIdxArr[ki];
+        if (permutedIdx == -1) {
+          continue;
+        }
+        
+        auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;
+        
+        float4 input =
+            vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
+        inputElemArr[ki] = *reinterpret_cast<InputElem*>(&input);
+      }
+      auto scaleArr = *reinterpret_cast<ScaleArrayType const*>(&scaleArrSmem[kChunkIdx]);
+      auto const scaleFloatArr = arrayConvert<ScaleArrayType, cutlass::Array<float, TopKUnrollFactor>>(scaleArr);
+
+      #pragma unroll
+      for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
+        auto const permutedIdx = permutedIdxArr[ki];
+        if (permutedIdx == -1) {
+          continue;
+        }
+
+        ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputElemArr[ki]);
+        threadOutput = threadOutput + scaleFloatArr[ki] * expertResult;
+      }
+    }
+    OutputElem outputElem = arrayConvert<ComputeElem, OutputElem>(threadOutput);
+    outElemPtr[elemIndex] = outputElem;
+  }
+}
+
+template <typename KernelParams>
+__global__ void finalizeKernelVecLoad_baseline(KernelParams params) {
+  using Type = typename KernelParams::Type;
+  using TypeExpW = typename KernelParams::TypeExpW;
 
   int const hiddenDimPaddedBits = params.hiddenDimPadded * cutlass::sizeof_bits<Type>::value;
   int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits<Type>::value;
@@ -813,7 +1011,7 @@ void run(Data const& data, void* stream) {
     int const numBlocksY = std::min(8192, data.numTokens);
     dim3 numBlocks(numBlocksX, numBlocksY);
 
-    LAUNCH_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream);
+    LAUNCH_TOPK_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream);
   } else {
     int const numThreads = 256;
     int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads;
@@ -827,9 +1025,10 @@ void run(Data const& data, void* stream) {
       // ensure that when the number of waves is greater than 1, we choose to use the kernel with
      // vectorized loading.
       dim3 numBlocks(numBlocksX, numBlocksY);
-      LAUNCH_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
+      LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
     } else {
-      LAUNCH_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
+      FLASHINFER_CHECK(data.topK <= MaxTopK, "Finalize kernel with vectorized loading is not supported for this TopK value: %d", data.topK);
+      LAUNCH_TOPK_EXPW(data, finalizeKernelVecLoad_baseline, /*numBlocks=*/data.numTokens,
                        /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
     }
   }
diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu
index b5ff5757c9..99824651b8 100644
--- a/csrc/trtllm_fused_moe_runner.cu
+++ b/csrc/trtllm_fused_moe_runner.cu
@@ -517,16 +517,16 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d
 
   auto const& config = mPassingConfigs[configIndex];
 
-  mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights,
-                    args.gemm1_weights_scale, workspace.token_scales, args.output1_scales_scalar,
-                    args.output1_scales_gate_scalar, args.gemm1_bias, args.gemm1_alpha,
-                    args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output,
-                    workspace.gemm1_output_scale, args.top_k, args.hidden_size,
-                    args.intermediate_size, args.local_num_experts, args.num_tokens,
-                    workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas,
-                    workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx,
-                    workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace,
-                    args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config, enable_pdl);
+  // mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights,
+  //                   args.gemm1_weights_scale, workspace.token_scales, args.output1_scales_scalar,
+  //                   args.output1_scales_gate_scalar, args.gemm1_bias, args.gemm1_alpha,
+  //                   args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output,
+  //                   workspace.gemm1_output_scale, args.top_k, args.hidden_size,
+  //                   args.intermediate_size, args.local_num_experts, args.num_tokens,
+  //                   workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas,
+  //                   workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx,
+  //                   workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace,
+  //                   args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config, enable_pdl);
 
   // We do not fuse activation with FC1 for DeepSeek FP8 due to the weights shuffling constraint.
   void* gemm2_input = workspace.gemm1_output;
@@ -540,13 +540,13 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d
   }
 
   // Run gemm2
-  mGemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale,
-             args.output2_scales_scalar, args.gemm2_bias, workspace.gemm2_output,
-             workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size,
-             args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas,
-             workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx,
-             workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream,
-             config.gemm2Config, enable_pdl);
+  // mGemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale,
+  //            args.output2_scales_scalar, args.gemm2_bias, workspace.gemm2_output,
+  //            workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size,
+  //            args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas,
+  //            workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx,
+  //            workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream,
+  //            config.gemm2Config, enable_pdl);
 
   // Run finalize
   if (args.do_finalize) {
diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py
index 78c19e98ac..ddcc8308f9 100644
--- a/flashinfer/jit/fused_moe.py
+++ b/flashinfer/jit/fused_moe.py
@@ -55,6 +55,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
         "-DENABLE_FP8",
         "-DENABLE_FP4",
         "-DUSING_OSS_CUTLASS_MOE_GEMM",
+        "-lineinfo"
     ]
 
     nvcc_flags += current_compilation_context.get_nvcc_flags_list(
diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h
index 0ee9ba6fe9..b38a4dc43b 100644
--- a/include/flashinfer/trtllm/fused_moe/DevKernel.h
+++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h
@@ -116,29 +116,40 @@ namespace moe::dev {
     FLASHINFER_WARN("Unsupported dtypeElt"); \
   }
 
-#define LAUNCH_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
+#define LAUNCH_EXPW(data, kernel, topK, numBlocks, numThreads, smemSize, stream) \
   if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float, topK), kernel, numBlocks, numThreads, \
               smemSize, stream); \
   } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float), kernel, numBlocks, \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float, topK), kernel, numBlocks, \
               numThreads, smemSize, stream); \
   } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float), kernel, numBlocks, numThreads, \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float, topK), kernel, numBlocks, numThreads, \
               smemSize, stream); \
   } else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t), kernel, numBlocks, \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t, topK), kernel, numBlocks, \
               numThreads, smemSize, stream); \
   } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t), kernel, \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t, topK), kernel, \
               numBlocks, numThreads, smemSize, stream); \
   } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t), kernel, \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, topK), kernel, \
               numBlocks, numThreads, smemSize, stream); \
   } else { \
     FLASHINFER_WARN("Unsupported pair"); \
   }
 
+#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
+  if (data.topK % 4 == 0) { \
+    LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream); \
+  } else if (data.topK % 2 == 0) { \
+    LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream); \
+  } else if (data.topK % 1 == 0) { \
+    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream); \
+  } else { \
+    FLASHINFER_WARN("Unsupported topK"); \
+  }
+
 #define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
   if (data.mPaddingLog2 > 0) { \
     LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, true), kernel, numBlocks, numThreads, smemSize, \
@@ -453,10 +464,11 @@ struct Data {
   int32_t const* totalNumPaddedTokens;
 };
 
-template <typename Type_, typename TypeExpW_, bool UsePdl_>
+template <typename Type_, typename TypeExpW_, int TopKUnrollFactor_, bool UsePdl_>
 struct KernelParams {
   using Type = Type_;
   using TypeExpW = TypeExpW_;
+  static constexpr int TopKUnrollFactor = TopKUnrollFactor_;
   static constexpr bool UsePdl = UsePdl_;
 
   Type const* inPtr;

From 3e4840262275e440a883d042d5a0ec73f8e0375f Mon Sep 17 00:00:00 2001
From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
Date: Fri, 14 Nov 2025 07:49:07 -0800
Subject: [PATCH 2/4] feat: finalize optimization

---
 csrc/trtllm_fused_moe_dev_kernel.cu         | 163 ++++++------------
 csrc/trtllm_fused_moe_runner.cu             |  34 ++--
 .../flashinfer/trtllm/fused_moe/DevKernel.h |  60 +++----
 3 files changed, 103 insertions(+), 154 deletions(-)

diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu
index 0fe088a174..5a17b227b8 100644
--- a/csrc/trtllm_fused_moe_dev_kernel.cu
+++ b/csrc/trtllm_fused_moe_dev_kernel.cu
@@ -672,81 +672,87 @@ __device__ float4 vectorizedLoadPtx(float4 const* ptr) {
 // Final kernel to unpermute and scale
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
 // connection.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 constexpr int MaxTopK = 64;
 
 typedef struct __CUDA_ALIGN__(4) {
   cutlass::bfloat16_t array[2];
-} __nv_bfloat16_2;
+} bfloat16_2;
 
 typedef struct __CUDA_ALIGN__(8) {
   cutlass::bfloat16_t array[4];
-} __nv_bfloat16_4;
+} bfloat16_4;
 
 typedef struct __CUDA_ALIGN__(8) {
   half array[4];
 } half_4;
 
-template<int VecSize_, typename TypeExpW_>
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int VecSize_, typename TypeExpW_>
 struct ScaleTraitsStruct;
 
-template<>
+template <>
 struct ScaleTraitsStruct<1, cutlass::bfloat16_t> {
   using PackedType = cutlass::bfloat16_t;
-  using ArrayType = cutlass::Array<__nv_bfloat16, 1>;
+  using ArrayType = cutlass::Array<cutlass::bfloat16_t, 1>;
 };
 
-template<>
+template <>
 struct ScaleTraitsStruct<2, cutlass::bfloat16_t> {
-  using PackedType = __nv_bfloat16_2;
-  using ArrayType = cutlass::Array<__nv_bfloat16, 2>;
+  using PackedType = bfloat16_2;
+  using ArrayType = cutlass::Array<cutlass::bfloat16_t, 2>;
 };
 
-template<>
+template <>
 struct ScaleTraitsStruct<4, cutlass::bfloat16_t> {
-  using PackedType = __nv_bfloat16_4;
+  using PackedType = bfloat16_4;
   using ArrayType = cutlass::Array<cutlass::bfloat16_t, 4>;
 };
 
-template<>
+template <>
 struct ScaleTraitsStruct<1, float> {
   using PackedType = float;
   using ArrayType = cutlass::Array<float, 1>;
 };
 
-template<>
+template <>
 struct ScaleTraitsStruct<2, float> {
   using PackedType = float2;
   using ArrayType = cutlass::Array<float, 2>;
 };
 
-template<>
+template <>
 struct ScaleTraitsStruct<4, float> {
   using PackedType = float4;
   using ArrayType = cutlass::Array<float, 4>;
 };
 
-template<>
+template <>
 struct ScaleTraitsStruct<1, half> {
   using PackedType = half;
   using ArrayType = cutlass::Array<half, 1>;
 };
 
-template<>
+template <>
 struct ScaleTraitsStruct<2, half> {
   using PackedType = half2;
   using ArrayType = cutlass::Array<half, 2>;
 };
 
-template<>
+template <>
 struct ScaleTraitsStruct<4, half> {
   using PackedType = half_4;
   using ArrayType = cutlass::Array<half, 4>;
 };
 
-template<int VecSize_, typename TypeExpW_>
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int VecSize_, typename TypeExpW_>
 struct FinalizeTraits;
 
-template<typename TypeExpW_>
+template <typename TypeExpW_>
 struct FinalizeTraits<1, TypeExpW_> {
   using IdxPackedType = int;
@@ -755,7 +761,7 @@ struct FinalizeTraits<1, TypeExpW_> {
   using ScaleArrayType = typename ScaleTraits::ArrayType;
 };
 
-template<typename TypeExpW_>
+template <typename TypeExpW_>
 struct FinalizeTraits<2, TypeExpW_> {
   using IdxPackedType = int2;
@@ -764,7 +770,7 @@ struct FinalizeTraits<2, TypeExpW_> {
   using ScaleArrayType = typename ScaleTraits::ArrayType;
 };
 
-template<typename TypeExpW_>
+template <typename TypeExpW_>
 struct FinalizeTraits<4, TypeExpW_> {
   using IdxPackedType = int4;
@@ -773,13 +779,16 @@ struct FinalizeTraits<4, TypeExpW_> {
   using ScaleArrayType = typename ScaleTraits::ArrayType;
 };
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template <typename KernelParams>
 __global__ void finalizeKernelVecLoad(KernelParams params) {
   using Type = typename KernelParams::Type;
   using TypeExpW = typename KernelParams::TypeExpW;
   int constexpr TopKUnrollFactor = KernelParams::TopKUnrollFactor;
 
-  static_assert(TopKUnrollFactor == 1 || TopKUnrollFactor == 2 || TopKUnrollFactor == 4, "TopKUnrollFactor must be 1, 2, or 4");
+  static_assert(TopKUnrollFactor == 1 || TopKUnrollFactor == 2 || TopKUnrollFactor == 4,
+                "TopKUnrollFactor must be 1, 2, or 4");
   using FinalizeTraits = FinalizeTraits<TopKUnrollFactor, TypeExpW>;
   using IdxPackedType = typename FinalizeTraits::IdxPackedType;
@@ -803,65 +812,65 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
   int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD;
   int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD;
 
-  __shared__ __nv_bfloat16_4 scaleArrSmem[MaxTopK / TopKUnrollFactor];
+  __shared__ ScalePackedType scaleArrSmem[MaxTopK / TopKUnrollFactor];
   __shared__ IdxPackedType permutedIdxArrSmem[MaxTopK / TopKUnrollFactor];
 
-  for (int kChunkIdx = threadIdx.x; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx += blockDim.x) {
+  for (int kChunkIdx = threadIdx.x; kChunkIdx < params.topK / TopKUnrollFactor;
+       kChunkIdx += blockDim.x) {
     int const expandedIdx = tokenIdx * params.topK + kChunkIdx * TopKUnrollFactor;
-    auto permutedIdxPacked = reinterpret_cast<IdxPackedType const*>(params.expandedIdxToPermutedIdx)[expandedIdx/4];
-    // auto permutedIdxArr = *reinterpret_cast<cutlass::Array<int, TopKUnrollFactor> const*>(&permutedIdxPacked);
+    auto permutedIdxPacked = reinterpret_cast<IdxPackedType const*>(
+        params.expandedIdxToPermutedIdx)[expandedIdx / TopKUnrollFactor];
     auto scalePacked = (params.expertWeightsPtr != nullptr)
-        ? reinterpret_cast<__nv_bfloat16_4 const*>(params.expertWeightsPtr)[expandedIdx/4]
-        : __nv_bfloat16_4{cutlass::bfloat16_t(1.f), cutlass::bfloat16_t(1.f), cutlass::bfloat16_t(1.f), cutlass::bfloat16_t(1.f)};
-    // auto scaleArr = *reinterpret_cast<ScaleArrayType const*>(&scalePacked);
+                           ? reinterpret_cast<ScalePackedType const*>(
+                                 params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
+                           : ScalePackedType{TypeExpW(1.f)};
 
     scaleArrSmem[kChunkIdx] = scalePacked;
     permutedIdxArrSmem[kChunkIdx] = permutedIdxPacked;
   }
-  
+
   auto const offset = tokenIdx * params.hiddenDim;
   Type* outputPtr = params.outPtr + offset;
   auto* outElemPtr = reinterpret_cast<OutputElem*>(outputPtr);
   auto const* inElemPtr = reinterpret_cast<InputElem const*>(params.inPtr);
-  
-  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
   // wait on primary kernel when using PDL
   if constexpr (KernelParams::UsePdl) {
     cudaGridDependencySynchronize();
   }
-  #endif
+#endif
   __syncthreads();
 
-
   for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride) {
     ComputeElem threadOutput;
     threadOutput.fill(0);
     for (int kChunkIdx = 0; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx++) {
-      auto permutedIdxArr = *reinterpret_cast<cutlass::Array<int, TopKUnrollFactor> const*>(&permutedIdxArrSmem[kChunkIdx]);
+      auto permutedIdxArr = *reinterpret_cast<IdxArrayType const*>(&permutedIdxArrSmem[kChunkIdx]);
       InputElem inputElemArr[TopKUnrollFactor];
-      #pragma unroll
+#pragma unroll
       for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
         auto const permutedIdx = permutedIdxArr[ki];
         if (permutedIdx == -1) {
           continue;
         }
-        
+
         auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;
-        
+
         float4 input =
-            vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
+            vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
         inputElemArr[ki] = *reinterpret_cast<InputElem*>(&input);
       }
       auto scaleArr = *reinterpret_cast<ScaleArrayType const*>(&scaleArrSmem[kChunkIdx]);
-      auto const scaleFloatArr = arrayConvert<ScaleArrayType, cutlass::Array<float, TopKUnrollFactor>>(scaleArr);
+      auto const scaleFloatArr =
+          arrayConvert<ScaleArrayType, cutlass::Array<float, TopKUnrollFactor>>(scaleArr);
 
-      #pragma unroll
+#pragma unroll
       for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
         auto const permutedIdx = permutedIdxArr[ki];
         if (permutedIdx == -1) {
           continue;
         }
-
         ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputElemArr[ki]);
         threadOutput = threadOutput + scaleFloatArr[ki] * expertResult;
       }
@@ -871,69 +880,6 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
   }
 }
 
-template <typename KernelParams>
-__global__ void finalizeKernelVecLoad_baseline(KernelParams params) {
-  using Type = typename KernelParams::Type;
-  using TypeExpW = typename KernelParams::TypeExpW;
-
-  int const hiddenDimPaddedBits = params.hiddenDimPadded * cutlass::sizeof_bits<Type>::value;
-  int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits<Type>::value;
-  assert(hiddenDimPaddedBits % 128 == 0);
-  assert(hiddenDimBits % 128 == 0);
-
-  // Load 128-bits per thread, according to the smallest data type we read/write
-  constexpr int64_t FINALIZE_ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<Type>::value;
-  using InputElem = cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD>;
-  using OutputElem = cutlass::Array<Type, FINALIZE_ELEM_PER_THREAD>;
-  using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
-
-  int64_t const tokenIdx = blockIdx.x;
-  int64_t const startOffset = threadIdx.x;
-  int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
-  int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD;
-  int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD;
-
-  auto const offset = tokenIdx * params.hiddenDim;
-  Type* outputPtr = params.outPtr + offset;
-  auto* outElemPtr = reinterpret_cast<OutputElem*>(outputPtr);
-  auto const* inElemPtr = reinterpret_cast<InputElem const*>(params.inPtr);
-
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-  // wait on primary kernel when using PDL
-  if constexpr (KernelParams::UsePdl) {
-    cudaGridDependencySynchronize();
-  }
-#endif
-
-  for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride) {
-    ComputeElem threadOutput;
-    threadOutput.fill(0);
-    for (int k = 0; k < params.topK; ++k) {
-      int const expandedIdx = tokenIdx * params.topK + k;
-      int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
-      if (permutedIdx == -1) {
-        continue;
-      }
-
-      float const scale = (params.expertWeightsPtr != nullptr)
-                              ? static_cast<float>(params.expertWeightsPtr[expandedIdx])
-                              : 1.f;
-
-      auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;
-
-      float4 input =
-          vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
-      InputElem inputPermutedElem = *reinterpret_cast<InputElem*>(&input);
-      ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputPermutedElem);
-
-      threadOutput = threadOutput + scale * expertResult;
-    }
-
-    OutputElem outputElem = arrayConvert<ComputeElem, OutputElem>(threadOutput);
-    outElemPtr[elemIndex] = outputElem;
-  }
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename KernelParams>
@@ -1027,9 +973,12 @@ void run(Data const& data, void* stream) {
       dim3 numBlocks(numBlocksX, numBlocksY);
       LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
     } else {
-      FLASHINFER_CHECK(data.topK <= MaxTopK, "Finalize kernel with vectorized loading is not supported for this TopK value: %d", data.topK);
-      LAUNCH_TOPK_EXPW(data, finalizeKernelVecLoad_baseline, /*numBlocks=*/data.numTokens,
-                       /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
+      FLASHINFER_CHECK(
+          data.topK <= MaxTopK,
+          "Finalize kernel with vectorized loading is not supported for this TopK value: %d",
+          data.topK);
+      LAUNCH_TOPK_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
+                       /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
     }
   }
 }
diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu
index 99824651b8..b5ff5757c9 100644
--- a/csrc/trtllm_fused_moe_runner.cu
+++ b/csrc/trtllm_fused_moe_runner.cu
@@ -517,16 +517,16 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d
 
   auto const& config = mPassingConfigs[configIndex];
 
-  // mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights,
-  //                   args.gemm1_weights_scale, workspace.token_scales, args.output1_scales_scalar,
-  //                   args.output1_scales_gate_scalar, args.gemm1_bias, args.gemm1_alpha,
-  //                   args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output,
-  //                   workspace.gemm1_output_scale, args.top_k, args.hidden_size,
-  //                   args.intermediate_size, args.local_num_experts, args.num_tokens,
-  //                   workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas,
-  //                   workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx,
-  //                   workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace,
-  //                   args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config, enable_pdl);
+  mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights,
+                    args.gemm1_weights_scale, workspace.token_scales, args.output1_scales_scalar,
+                    args.output1_scales_gate_scalar, args.gemm1_bias, args.gemm1_alpha,
+                    args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output,
+                    workspace.gemm1_output_scale, args.top_k, args.hidden_size,
+                    args.intermediate_size, args.local_num_experts, args.num_tokens,
+                    workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas,
+                    workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx,
+                    workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace,
+                    args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config, enable_pdl);
 
   // We do not fuse activation with FC1 for DeepSeek FP8 due to the weights shuffling constraint.
   void* gemm2_input = workspace.gemm1_output;
@@ -540,13 +540,13 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d
   }
 
   // Run gemm2
-  // mGemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale,
-  //            args.output2_scales_scalar, args.gemm2_bias, workspace.gemm2_output,
-  //            workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size,
-  //            args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas,
-  //            workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx,
-  //            workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream,
-  //            config.gemm2Config, enable_pdl);
+  mGemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale,
+             args.output2_scales_scalar, args.gemm2_bias, workspace.gemm2_output,
+             workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size,
+             args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas,
+             workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx,
+             workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream,
+             config.gemm2Config, enable_pdl);
 
   // Run finalize
   if (args.do_finalize) {
diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h
index b38a4dc43b..f0056335f0 100644
--- a/include/flashinfer/trtllm/fused_moe/DevKernel.h
+++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h
@@ -116,38 +116,38 @@ namespace moe::dev {
     FLASHINFER_WARN("Unsupported dtypeElt"); \
   }
 
-#define LAUNCH_EXPW(data, kernel, topK, numBlocks, numThreads, smemSize, stream) \
-  if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float, topK), kernel, numBlocks, numThreads, \
-              smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float, topK), kernel, numBlocks, \
-              numThreads, smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float, topK), kernel, numBlocks, numThreads, \
-              smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t, topK), kernel, numBlocks, \
-              numThreads, smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t, topK), kernel, \
-              numBlocks, numThreads, smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, topK), kernel, \
-              numBlocks, numThreads, smemSize, stream); \
-  } else { \
-    FLASHINFER_WARN("Unsupported pair"); \
+#define LAUNCH_EXPW(data, kernel, topK, numBlocks, numThreads, smemSize, stream)                  \
+  if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) {                  \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float, topK), kernel, numBlocks,          \
+               numThreads, smemSize, stream);                                                     \
+  } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) {           \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float, topK), kernel, numBlocks,    \
+               numThreads, smemSize, stream);                                                     \
+  } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) {       \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float, topK), kernel, numBlocks,      \
+               numThreads, smemSize, stream);                                                     \
+  } else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) {       \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t, topK), kernel,       \
+               numBlocks, numThreads, smemSize, stream);                                          \
+  } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) {       \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t, topK), kernel, \
+               numBlocks, numThreads, smemSize, stream);                                          \
+  } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) {   \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, topK), kernel,   \
+               numBlocks, numThreads, smemSize, stream);                                          \
+  } else {                                                                                        \
+    FLASHINFER_WARN("Unsupported pair");                                                          \
   }
 
-#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
-  if (data.topK % 4 == 0) { \
-    LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream); \
-  } else if (data.topK % 2 == 0) { \
-    LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream); \
-  } else if (data.topK % 1 == 0) { \
-    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream); \
-  } else { \
-    FLASHINFER_WARN("Unsupported topK"); \
+#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
+  if (data.topK % 4 == 0) {                                                     \
+    LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream);      \
+  } else if (data.topK % 2 == 0) {                                              \
+    LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream);      \
+  } else if (data.topK % 1 == 0) {                                              \
+    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream);      \
+  } else {                                                                      \
+    FLASHINFER_WARN("Unsupported topK");                                        \
   }
 
 #define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \

From 314581932b83907da1e2a7978b1aed2d7657698b Mon Sep 17 00:00:00 2001
From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
Date: Fri, 14 Nov 2025 07:56:40 -0800
Subject: [PATCH 3/4] pre-commit

---
 flashinfer/jit/fused_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py
index ddcc8308f9..3c050fa1e5 100644
--- a/flashinfer/jit/fused_moe.py
+++ b/flashinfer/jit/fused_moe.py
@@ -55,7 +55,7 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
         "-DENABLE_FP8",
         "-DENABLE_FP4",
         "-DUSING_OSS_CUTLASS_MOE_GEMM",
-        "-lineinfo"
+        "-lineinfo",
     ]
 
     nvcc_flags += current_compilation_context.get_nvcc_flags_list(

From b2ca5dfa23e4cf2878d2fca52aa8cdb705eeaf31 Mon Sep 17 00:00:00 2001
From: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
Date: Fri, 14 Nov 2025 08:25:24 -0800
Subject: [PATCH 4/4] review

---
 csrc/trtllm_fused_moe_dev_kernel.cu             | 11 ++++++-----
 flashinfer/jit/fused_moe.py                     |  1 -
 include/flashinfer/trtllm/fused_moe/DevKernel.h |  4 +---
 3 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu
index 5a17b227b8..7a58042041 100644
--- a/csrc/trtllm_fused_moe_dev_kernel.cu
+++ b/csrc/trtllm_fused_moe_dev_kernel.cu
@@ -811,6 +811,7 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
   int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
   int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD;
   int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD;
+  bool const useScale = params.expertWeightsPtr != nullptr;
 
   __shared__ ScalePackedType scaleArrSmem[MaxTopK / TopKUnrollFactor];
   __shared__ IdxPackedType permutedIdxArrSmem[MaxTopK / TopKUnrollFactor];
@@ -820,10 +821,9 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
     int const expandedIdx = tokenIdx * params.topK + kChunkIdx * TopKUnrollFactor;
     auto permutedIdxPacked = reinterpret_cast<IdxPackedType const*>(
         params.expandedIdxToPermutedIdx)[expandedIdx / TopKUnrollFactor];
-    auto scalePacked = (params.expertWeightsPtr != nullptr)
-                           ? reinterpret_cast<ScalePackedType const*>(
-                                 params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
-                           : ScalePackedType{TypeExpW(1.f)};
+    auto scalePacked = useScale ? reinterpret_cast<ScalePackedType const*>(
+                                      params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
+                                : ScalePackedType{TypeExpW(1.f)};
 
     scaleArrSmem[kChunkIdx] = scalePacked;
@@ -871,8 +871,9 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
       if (permutedIdx == -1) {
        continue;
       }
+      auto scale = useScale ? scaleFloatArr[ki] : 1.0f;
       ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputElemArr[ki]);
-      threadOutput = threadOutput + scaleFloatArr[ki] * expertResult;
+      threadOutput = threadOutput + scale * expertResult;
     }
   }
   OutputElem outputElem = arrayConvert<ComputeElem, OutputElem>(threadOutput);
diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py
index 3c050fa1e5..78c19e98ac 100644
--- a/flashinfer/jit/fused_moe.py
+++ b/flashinfer/jit/fused_moe.py
@@ -55,7 +55,6 @@ def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
         "-DENABLE_FP8",
         "-DENABLE_FP4",
         "-DUSING_OSS_CUTLASS_MOE_GEMM",
-        "-lineinfo",
     ]
 
     nvcc_flags += current_compilation_context.get_nvcc_flags_list(
diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h
index f0056335f0..23abb87a7b 100644
--- a/include/flashinfer/trtllm/fused_moe/DevKernel.h
+++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h
@@ -144,10 +144,8 @@ namespace moe::dev {
     LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream);      \
   } else if (data.topK % 2 == 0) {                                              \
     LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream);      \
-  } else if (data.topK % 1 == 0) {                                              \
-    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream);      \
   } else {                                                                      \
-    FLASHINFER_WARN("Unsupported topK");                                        \
+    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream);      \
   }
 
 #define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
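
Illustrative sketch (not part of the patch series above): the series speeds up the MoE finalize step by staging each token's topK permuted-row indices and expert scales in shared memory via packed loads, and by unrolling the k-way reduction with a compile-time factor chosen from topK's divisibility (4, 2, else 1) through LAUNCH_TOPK_EXPW. The standalone CUDA sketch below shows only that dispatch-plus-unroll pattern under simplifying assumptions: float data, scalar (not 128-bit packed) staging, and no PDL, padding, or dtype dispatch. All names here are illustrative, not the patch's API.

#include <cuda_runtime.h>

constexpr int kMaxTopK = 64;  // mirrors MaxTopK in the patch

template <int Unroll>
__global__ void finalizeSketch(float const* in, float* out, int const* expandedIdxToPermutedIdx,
                               float const* scales, int topK, int hiddenDim) {
  // One block per token: stage this token's indices and scales in shared memory once.
  __shared__ int idxSmem[kMaxTopK];
  __shared__ float scaleSmem[kMaxTopK];
  int const tokenIdx = blockIdx.x;
  for (int k = threadIdx.x; k < topK; k += blockDim.x) {
    int const expandedIdx = tokenIdx * topK + k;
    idxSmem[k] = expandedIdxToPermutedIdx[expandedIdx];
    scaleSmem[k] = scales != nullptr ? scales[expandedIdx] : 1.f;
  }
  __syncthreads();

  // k-way reduction over the permuted rows, processed in fully unrolled chunks.
  for (int e = threadIdx.x; e < hiddenDim; e += blockDim.x) {
    float acc = 0.f;
    for (int kBase = 0; kBase < topK; kBase += Unroll) {
#pragma unroll
      for (int ki = 0; ki < Unroll; ++ki) {
        int const permutedIdx = idxSmem[kBase + ki];
        if (permutedIdx == -1) continue;  // slot not materialized (e.g. expert not on this rank)
        acc += scaleSmem[kBase + ki] * in[permutedIdx * hiddenDim + e];
      }
    }
    out[tokenIdx * hiddenDim + e] = acc;
  }
}

// Host-side dispatch mirroring LAUNCH_TOPK_EXPW: pick the largest unroll factor that
// divides topK, so the unrolled inner loop never reads past a token's topK entries.
void launchFinalizeSketch(float const* in, float* out, int const* map, float const* scales,
                          int numTokens, int topK, int hiddenDim, cudaStream_t stream) {
  // assumes 0 < topK <= kMaxTopK (the patch enforces this with FLASHINFER_CHECK)
  dim3 const grid(numTokens);
  dim3 const block(256);
  if (topK % 4 == 0) {
    finalizeSketch<4><<<grid, block, 0, stream>>>(in, out, map, scales, topK, hiddenDim);
  } else if (topK % 2 == 0) {
    finalizeSketch<2><<<grid, block, 0, stream>>>(in, out, map, scales, topK, hiddenDim);
  } else {
    finalizeSketch<1><<<grid, block, 0, stream>>>(in, out, map, scales, topK, hiddenDim);
  }
}

As in the patch, resolving the unroll factor on the host keeps the per-chunk inner loop fully unrollable with no remainder loop, at the cost of compiling one kernel instance per factor.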