diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu index 9a51384090..7a58042041 100644 --- a/csrc/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/trtllm_fused_moe_dev_kernel.cu @@ -672,11 +672,128 @@ __device__ float4 vectorizedLoadPtx(float4 const* ptr) { // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip // connection. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int MaxTopK = 64; + +typedef struct __CUDA_ALIGN__(4) { + cutlass::bfloat16_t array[2]; +} bfloat16_2; + +typedef struct __CUDA_ALIGN__(8) { + cutlass::bfloat16_t array[4]; +} bfloat16_4; + +typedef struct __CUDA_ALIGN__(8) { + half array[4]; +} half_4; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ScaleTraitsStruct; + +template <> +struct ScaleTraitsStruct<1, cutlass::bfloat16_t> { + using PackedType = cutlass::bfloat16_t; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<2, cutlass::bfloat16_t> { + using PackedType = bfloat16_2; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<4, cutlass::bfloat16_t> { + using PackedType = bfloat16_4; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<1, float> { + using PackedType = float; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<2, float> { + using PackedType = float2; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<4, float> { + using PackedType = float4; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<1, half> { + using PackedType = half; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<2, half> { + using PackedType = half2; + using ArrayType = cutlass::Array; +}; + +template <> +struct ScaleTraitsStruct<4, half> { + using PackedType = half_4; + using ArrayType = cutlass::Array; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FinalizeTraits; + +template +struct FinalizeTraits<1, TypeExpW_> { + using IdxPackedType = int; + using IdxArrayType = cutlass::Array; + using ScaleTraits = ScaleTraitsStruct<1, TypeExpW_>; + using ScalePackedType = typename ScaleTraits::PackedType; + using ScaleArrayType = typename ScaleTraits::ArrayType; +}; + +template +struct FinalizeTraits<2, TypeExpW_> { + using IdxPackedType = int2; + using IdxArrayType = cutlass::Array; + using ScaleTraits = ScaleTraitsStruct<2, TypeExpW_>; + using ScalePackedType = typename ScaleTraits::PackedType; + using ScaleArrayType = typename ScaleTraits::ArrayType; +}; + +template +struct FinalizeTraits<4, TypeExpW_> { + using IdxPackedType = int4; + using IdxArrayType = cutlass::Array; + using ScaleTraits = ScaleTraitsStruct<4, TypeExpW_>; + using ScalePackedType = typename ScaleTraits::PackedType; + using ScaleArrayType = typename ScaleTraits::ArrayType; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void finalizeKernelVecLoad(KernelParams params) { using Type = typename KernelParams::Type; using TypeExpW = typename KernelParams::TypeExpW; + int constexpr TopKUnrollFactor = KernelParams::TopKUnrollFactor; + + static_assert(TopKUnrollFactor == 1 || TopKUnrollFactor == 2 || TopKUnrollFactor == 4, + "TopKUnrollFactor must be 1, 2, or 4"); + using FinalizeTraits = FinalizeTraits; + using IdxPackedType = typename FinalizeTraits::IdxPackedType; + using IdxArrayType = typename FinalizeTraits::IdxArrayType; + using ScalePackedType = typename FinalizeTraits::ScalePackedType; + using ScaleArrayType = typename FinalizeTraits::ScaleArrayType; int const hiddenDimPaddedBits = params.hiddenDimPadded * cutlass::sizeof_bits::value; int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits::value; @@ -694,6 +811,23 @@ __global__ void finalizeKernelVecLoad(KernelParams params) { int64_t const stride = FINALIZE_THREADS_PER_BLOCK; int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD; int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD; + bool const useScale = params.expertWeightsPtr != nullptr; + + __shared__ ScalePackedType scaleArrSmem[MaxTopK / TopKUnrollFactor]; + __shared__ IdxPackedType permutedIdxArrSmem[MaxTopK / TopKUnrollFactor]; + + for (int kChunkIdx = threadIdx.x; kChunkIdx < params.topK / TopKUnrollFactor; + kChunkIdx += blockDim.x) { + int const expandedIdx = tokenIdx * params.topK + kChunkIdx * TopKUnrollFactor; + auto permutedIdxPacked = reinterpret_cast( + params.expandedIdxToPermutedIdx)[expandedIdx / TopKUnrollFactor]; + auto scalePacked = useScale ? reinterpret_cast( + params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor] + : ScalePackedType{TypeExpW(1.f)}; + + scaleArrSmem[kChunkIdx] = scalePacked; + permutedIdxArrSmem[kChunkIdx] = permutedIdxPacked; + } auto const offset = tokenIdx * params.hiddenDim; Type* outputPtr = params.outPtr + offset; @@ -706,31 +840,42 @@ __global__ void finalizeKernelVecLoad(KernelParams params) { cudaGridDependencySynchronize(); } #endif + __syncthreads(); for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride) { ComputeElem threadOutput; threadOutput.fill(0); - for (int k = 0; k < params.topK; ++k) { - int const expandedIdx = tokenIdx * params.topK + k; - int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx]; - if (permutedIdx == -1) { - continue; - } - - float const scale = (params.expertWeightsPtr != nullptr) - ? static_cast(params.expertWeightsPtr[expandedIdx]) - : 1.f; + for (int kChunkIdx = 0; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx++) { + auto permutedIdxArr = *reinterpret_cast(&permutedIdxArrSmem[kChunkIdx]); + InputElem inputElemArr[TopKUnrollFactor]; +#pragma unroll + for (int ki = 0; ki < TopKUnrollFactor; ++ki) { + auto const permutedIdx = permutedIdxArr[ki]; + if (permutedIdx == -1) { + continue; + } - auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol; + auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol; - float4 input = - vectorizedLoadPtx(reinterpret_cast(&inputPermutedPtr[elemIndex])); - InputElem inputPermutedElem = *reinterpret_cast(&input); - ComputeElem expertResult = arrayConvert(inputPermutedElem); + float4 input = + vectorizedLoadPtx(reinterpret_cast(&inputPermutedPtr[elemIndex])); + inputElemArr[ki] = *reinterpret_cast(&input); + } + auto scaleArr = *reinterpret_cast(&scaleArrSmem[kChunkIdx]); + auto const scaleFloatArr = + arrayConvert>(scaleArr); - threadOutput = threadOutput + scale * expertResult; +#pragma unroll + for (int ki = 0; ki < TopKUnrollFactor; ++ki) { + auto const permutedIdx = permutedIdxArr[ki]; + if (permutedIdx == -1) { + continue; + } + auto scale = useScale ? scaleFloatArr[ki] : 1.0f; + ComputeElem expertResult = arrayConvert(inputElemArr[ki]); + threadOutput = threadOutput + scale * expertResult; + } } - OutputElem outputElem = arrayConvert(threadOutput); outElemPtr[elemIndex] = outputElem; } @@ -813,7 +958,7 @@ void run(Data const& data, void* stream) { int const numBlocksY = std::min(8192, data.numTokens); dim3 numBlocks(numBlocksX, numBlocksY); - LAUNCH_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream); + LAUNCH_TOPK_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream); } else { int const numThreads = 256; int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads; @@ -827,10 +972,14 @@ void run(Data const& data, void* stream) { // ensure that when the number of waves is greater than 1, we choose to use the kernel with // vectorized loading. dim3 numBlocks(numBlocksX, numBlocksY); - LAUNCH_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream); + LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream); } else { - LAUNCH_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens, - /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream); + FLASHINFER_CHECK( + data.topK <= MaxTopK, + "Finalize kernel with vectorized loading is not supported for this TopK value: %d", + data.topK); + LAUNCH_TOPK_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens, + /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream); } } } diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 0ee9ba6fe9..23abb87a7b 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -116,27 +116,36 @@ namespace moe::dev { FLASHINFER_WARN("Unsupported dtypeElt"); \ } -#define LAUNCH_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, \ - smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float), kernel, numBlocks, numThreads, \ - smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported pair"); \ +#define LAUNCH_EXPW(data, kernel, topK, numBlocks, numThreads, smemSize, stream) \ + if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float, topK), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float, topK), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float, topK), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t, topK), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t, topK), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, topK), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported pair"); \ + } + +#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.topK % 4 == 0) { \ + LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream); \ + } else if (data.topK % 2 == 0) { \ + LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream); \ + } else { \ + LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream); \ } #define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ @@ -453,10 +462,11 @@ struct Data { int32_t const* totalNumPaddedTokens; }; -template +template struct KernelParams { using Type = Type_; using TypeExpW = TypeExpW_; + static constexpr int TopKUnrollFactor = TopKUnrollFactor_; static constexpr bool UsePdl = UsePdl_; Type const* inPtr;