Commit cce4952
perf: TRT-LLM Gen finalize kernel optimization (#2092)
## 📌 Description

Small optimization for the TRT-LLM Gen MoE finalize kernel.

Benchmark: TopK=8, NumExperts=128, HiddenSize=4096

| BS | Baseline, us | Optimized, us | Speed-up |
| ------ | ------------ | ------------- | -------- |
| 256    | 11           | 6             | 1.83     |
| 512    | 12           | 7             | 1.71     |
| 1024   | 16           | 15            | 1.06     |
| 4096   | 55           | 49            | 1.12     |
| 8192   | 107          | 95            | 1.13     |
| 16384  | 205          | 183           | 1.12     |

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Enabled a vectorized, TopK-unrolled finalize path for MoE (Mixture of Experts) kernel operations with improved performance.
  * Added support for multiple data types (bfloat16, float, half) with enhanced type specialization and packing.
  * Introduced runtime validation for TopK configurations (≤ 64) to ensure the vectorized path is only used where it is valid.
1 parent ba8f3ed commit cce4952

File tree: 2 files changed (+202 −43 lines changed)


### csrc/trtllm_fused_moe_dev_kernel.cu

Lines changed: 170 additions & 21 deletions
```diff
@@ -672,11 +672,128 @@ __device__ float4 vectorizedLoadPtx(float4 const* ptr) {
 // Final kernel to unpermute and scale
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
 // connection.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+constexpr int MaxTopK = 64;
+
+typedef struct __CUDA_ALIGN__(4) {
+  cutlass::bfloat16_t array[2];
+} bfloat16_2;
+
+typedef struct __CUDA_ALIGN__(8) {
+  cutlass::bfloat16_t array[4];
+} bfloat16_4;
+
+typedef struct __CUDA_ALIGN__(8) {
+  half array[4];
+} half_4;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int UnrollFactor_, typename TypeExpW_>
+struct ScaleTraitsStruct;
+
+template <>
+struct ScaleTraitsStruct<1, cutlass::bfloat16_t> {
+  using PackedType = cutlass::bfloat16_t;
+  using ArrayType = cutlass::Array<cutlass::bfloat16_t, 1>;
+};
+
+template <>
+struct ScaleTraitsStruct<2, cutlass::bfloat16_t> {
+  using PackedType = bfloat16_2;
+  using ArrayType = cutlass::Array<cutlass::bfloat16_t, 2>;
+};
+
+template <>
+struct ScaleTraitsStruct<4, cutlass::bfloat16_t> {
+  using PackedType = bfloat16_4;
+  using ArrayType = cutlass::Array<cutlass::bfloat16_t, 4>;
+};
+
+template <>
+struct ScaleTraitsStruct<1, float> {
+  using PackedType = float;
+  using ArrayType = cutlass::Array<float, 1>;
+};
+
+template <>
+struct ScaleTraitsStruct<2, float> {
+  using PackedType = float2;
+  using ArrayType = cutlass::Array<float, 2>;
+};
+
+template <>
+struct ScaleTraitsStruct<4, float> {
+  using PackedType = float4;
+  using ArrayType = cutlass::Array<float, 4>;
+};
+
+template <>
+struct ScaleTraitsStruct<1, half> {
+  using PackedType = half;
+  using ArrayType = cutlass::Array<half, 1>;
+};
+
+template <>
+struct ScaleTraitsStruct<2, half> {
+  using PackedType = half2;
+  using ArrayType = cutlass::Array<half, 2>;
+};
+
+template <>
+struct ScaleTraitsStruct<4, half> {
+  using PackedType = half_4;
+  using ArrayType = cutlass::Array<half, 4>;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int UnrollFactor_, typename TypeExpW_>
+struct FinalizeTraits;
+
+template <typename TypeExpW_>
+struct FinalizeTraits<1, TypeExpW_> {
+  using IdxPackedType = int;
+  using IdxArrayType = cutlass::Array<int, 1>;
+  using ScaleTraits = ScaleTraitsStruct<1, TypeExpW_>;
+  using ScalePackedType = typename ScaleTraits::PackedType;
+  using ScaleArrayType = typename ScaleTraits::ArrayType;
+};
+
+template <typename TypeExpW_>
+struct FinalizeTraits<2, TypeExpW_> {
+  using IdxPackedType = int2;
+  using IdxArrayType = cutlass::Array<int, 2>;
+  using ScaleTraits = ScaleTraitsStruct<2, TypeExpW_>;
+  using ScalePackedType = typename ScaleTraits::PackedType;
+  using ScaleArrayType = typename ScaleTraits::ArrayType;
+};
+
+template <typename TypeExpW_>
+struct FinalizeTraits<4, TypeExpW_> {
+  using IdxPackedType = int4;
+  using IdxArrayType = cutlass::Array<int, 4>;
+  using ScaleTraits = ScaleTraitsStruct<4, TypeExpW_>;
+  using ScalePackedType = typename ScaleTraits::PackedType;
+  using ScaleArrayType = typename ScaleTraits::ArrayType;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename KernelParams>
 __global__ void finalizeKernelVecLoad(KernelParams params) {
   using Type = typename KernelParams::Type;
   using TypeExpW = typename KernelParams::TypeExpW;
+  int constexpr TopKUnrollFactor = KernelParams::TopKUnrollFactor;
+
+  static_assert(TopKUnrollFactor == 1 || TopKUnrollFactor == 2 || TopKUnrollFactor == 4,
+                "TopKUnrollFactor must be 1, 2, or 4");
+  using FinalizeTraits = FinalizeTraits<TopKUnrollFactor, TypeExpW>;
+  using IdxPackedType = typename FinalizeTraits::IdxPackedType;
+  using IdxArrayType = typename FinalizeTraits::IdxArrayType;
+  using ScalePackedType = typename FinalizeTraits::ScalePackedType;
+  using ScaleArrayType = typename FinalizeTraits::ScaleArrayType;
 
   int const hiddenDimPaddedBits = params.hiddenDimPadded * cutlass::sizeof_bits<Type>::value;
   int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits<Type>::value;
```
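Note on the new traits above: `ScaleTraitsStruct` and `FinalizeTraits` map a compile-time unroll factor (1, 2, or 4) plus the expert-weight type to a packed type that can be read with a single aligned load, alongside a `cutlass::Array` view for per-element access. Below is a minimal standalone sketch of the same pattern, reduced to `float` and stripped of the CUTLASS dependencies; `ScaleTraitsSketch` and the `main` driver are illustrative only and not part of the diff.

```cuda
// Standalone sketch (not the actual code): map a compile-time unroll factor
// to a packed type sized for one aligned load. Reduced to float; the real
// traits also cover cutlass::bfloat16_t and half and expose a cutlass::Array
// element view in addition to the packed type.
#include <cstdio>

template <int UnrollFactor>
struct ScaleTraitsSketch;

template <>
struct ScaleTraitsSketch<1> {
  using PackedType = float;  // one 4-byte load
};

template <>
struct ScaleTraitsSketch<2> {
  struct alignas(8) PackedType {
    float array[2];  // one 8-byte load
  };
};

template <>
struct ScaleTraitsSketch<4> {
  struct alignas(16) PackedType {
    float array[4];  // one 16-byte load
  };
};

int main() {
  printf("packed sizes: %zu %zu %zu bytes\n", sizeof(ScaleTraitsSketch<1>::PackedType),
         sizeof(ScaleTraitsSketch<2>::PackedType), sizeof(ScaleTraitsSketch<4>::PackedType));
  // prints: packed sizes: 4 8 16 bytes
  return 0;
}
```

Those 4/8/16-byte packed sizes are what allow the kernel to replace up to four scalar reads of expert weights or indices with one vectorized access.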
```diff
@@ -694,6 +811,23 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
   int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
   int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD;
   int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD;
+  bool const useScale = params.expertWeightsPtr != nullptr;
+
+  __shared__ ScalePackedType scaleArrSmem[MaxTopK / TopKUnrollFactor];
+  __shared__ IdxPackedType permutedIdxArrSmem[MaxTopK / TopKUnrollFactor];
+
+  for (int kChunkIdx = threadIdx.x; kChunkIdx < params.topK / TopKUnrollFactor;
+       kChunkIdx += blockDim.x) {
+    int const expandedIdx = tokenIdx * params.topK + kChunkIdx * TopKUnrollFactor;
+    auto permutedIdxPacked = reinterpret_cast<IdxPackedType const*>(
+        params.expandedIdxToPermutedIdx)[expandedIdx / TopKUnrollFactor];
+    auto scalePacked = useScale ? reinterpret_cast<ScalePackedType const*>(
+                                      params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
+                                : ScalePackedType{TypeExpW(1.f)};
+
+    scaleArrSmem[kChunkIdx] = scalePacked;
+    permutedIdxArrSmem[kChunkIdx] = permutedIdxPacked;
+  }
 
   auto const offset = tokenIdx * params.hiddenDim;
   Type* outputPtr = params.outPtr + offset;
```
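The staging loop added here reads the permuted indices and expert weights for a token once per block, as packed `IdxPackedType`/`ScalePackedType` values, and parks them in shared memory so the per-element loop below no longer re-reads them from global memory for every hidden-dim chunk. A minimal sketch of the packed-index staging, assuming `topK % 4 == 0` and `topK <= 64` (the kernel name `stageIndices` and the `out` parameter are illustrative, not part of the actual change):

```cuda
// Minimal sketch (not the actual kernel): stage topK permuted indices into
// shared memory with packed int4 loads, one 16-byte load per group of 4.
#include <cuda_runtime.h>

constexpr int kMaxTopK = 64;

__global__ void stageIndices(int const* expandedIdxToPermutedIdx, int topK, int* out) {
  int const tokenIdx = blockIdx.x;
  __shared__ int4 permutedIdxSmem[kMaxTopK / 4];

  // Cooperative staging: each thread loads a different packed group.
  for (int chunk = threadIdx.x; chunk < topK / 4; chunk += blockDim.x) {
    int const expandedIdx = tokenIdx * topK + chunk * 4;
    permutedIdxSmem[chunk] =
        reinterpret_cast<int4 const*>(expandedIdxToPermutedIdx)[expandedIdx / 4];
  }
  __syncthreads();

  // Every thread can now read the staged indices without touching global memory.
  if (threadIdx.x == 0) {
    for (int k = 0; k < topK; ++k) {
      out[tokenIdx * topK + k] = reinterpret_cast<int const*>(permutedIdxSmem)[k];
    }
  }
}
```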
```diff
@@ -706,31 +840,42 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
     cudaGridDependencySynchronize();
   }
 #endif
+  __syncthreads();
 
   for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride) {
     ComputeElem threadOutput;
     threadOutput.fill(0);
-    for (int k = 0; k < params.topK; ++k) {
-      int const expandedIdx = tokenIdx * params.topK + k;
-      int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
-      if (permutedIdx == -1) {
-        continue;
-      }
-
-      float const scale = (params.expertWeightsPtr != nullptr)
-                              ? static_cast<float>(params.expertWeightsPtr[expandedIdx])
-                              : 1.f;
+    for (int kChunkIdx = 0; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx++) {
+      auto permutedIdxArr = *reinterpret_cast<IdxArrayType const*>(&permutedIdxArrSmem[kChunkIdx]);
+      InputElem inputElemArr[TopKUnrollFactor];
+#pragma unroll
+      for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
+        auto const permutedIdx = permutedIdxArr[ki];
+        if (permutedIdx == -1) {
+          continue;
+        }
 
-      auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;
+        auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;
 
-      float4 input =
-          vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
-      InputElem inputPermutedElem = *reinterpret_cast<InputElem const*>(&input);
-      ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputPermutedElem);
+        float4 input =
+            vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
+        inputElemArr[ki] = *reinterpret_cast<InputElem const*>(&input);
+      }
+      auto scaleArr = *reinterpret_cast<ScaleArrayType const*>(&scaleArrSmem[kChunkIdx]);
+      auto const scaleFloatArr =
+          arrayConvert<ScaleArrayType, cutlass::Array<float, TopKUnrollFactor>>(scaleArr);
 
-      threadOutput = threadOutput + scale * expertResult;
+#pragma unroll
+      for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
+        auto const permutedIdx = permutedIdxArr[ki];
+        if (permutedIdx == -1) {
+          continue;
+        }
+        auto scale = useScale ? scaleFloatArr[ki] : 1.0f;
+        ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputElemArr[ki]);
+        threadOutput = threadOutput + scale * expertResult;
+      }
     }
-
     OutputElem outputElem = arrayConvert<ComputeElem, OutputElem>(threadOutput);
     outElemPtr[elemIndex] = outputElem;
   }
```
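The reduction loop is now split into two `#pragma unroll` passes per chunk: the first issues all `TopKUnrollFactor` vectorized loads so they can be in flight together, the second converts and accumulates them with the prefetched scales. A simplified scalar sketch of that load-then-accumulate structure (illustrative only; `accumulateChunk` is not a function in the diff):

```cuda
// Illustrative scalar sketch: gather all UnrollFactor inputs first, then apply
// the scales, so the loads can overlap instead of serializing
// load -> convert -> multiply-add for each expert in turn.
template <int UnrollFactor>
__device__ float accumulateChunk(float const* inputs, float const* scales) {
  float staged[UnrollFactor];
#pragma unroll
  for (int ki = 0; ki < UnrollFactor; ++ki) {
    staged[ki] = inputs[ki];  // issue all loads up front
  }
  float acc = 0.f;
#pragma unroll
  for (int ki = 0; ki < UnrollFactor; ++ki) {
    acc += scales[ki] * staged[ki];  // then do the math
  }
  return acc;
}
```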
```diff
@@ -813,7 +958,7 @@ void run(Data const& data, void* stream) {
     int const numBlocksY = std::min(8192, data.numTokens);
     dim3 numBlocks(numBlocksX, numBlocksY);
 
-    LAUNCH_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream);
+    LAUNCH_TOPK_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream);
   } else {
     int const numThreads = 256;
    int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads;
```
```diff
@@ -827,10 +972,14 @@ void run(Data const& data, void* stream) {
      // ensure that when the number of waves is greater than 1, we choose to use the kernel with
      // vectorized loading.
      dim3 numBlocks(numBlocksX, numBlocksY);
-      LAUNCH_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
+      LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
    } else {
-      LAUNCH_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
-                  /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
+      FLASHINFER_CHECK(
+          data.topK <= MaxTopK,
+          "Finalize kernel with vectorized loading is not supported for this TopK value: %d",
+          data.topK);
+      LAUNCH_TOPK_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
+                       /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
    }
  }
}
```
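Because the staging buffers are statically sized with `MaxTopK` (64), the vectorized-load path now rejects larger `topK` values up front via `FLASHINFER_CHECK`. The per-block shared-memory cost of those buffers is small and does not change with the unroll factor; a rough back-of-the-envelope sketch (assuming 4-byte indices and 2-byte expert-weight types such as bf16/half):

```cuda
// Back-of-the-envelope sketch of the staging buffers' shared-memory cost for
// the vectorized path. The packed element count MaxTopK / TopKUnrollFactor
// shrinks as the packed element size grows, so the total stays constant.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kMaxTopK = 64;
  constexpr size_t idxBytes = kMaxTopK * sizeof(int32_t);     // 256 bytes of indices
  constexpr size_t scaleBytes = kMaxTopK * sizeof(uint16_t);  // 128 bytes for bf16/half scales
  printf("staging smem per block: %zu bytes\n", idxBytes + scaleBytes);  // 384 bytes
  return 0;
}
```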

### include/flashinfer/trtllm/fused_moe/DevKernel.h

Lines changed: 32 additions & 22 deletions
```diff
@@ -116,27 +116,36 @@ namespace moe::dev {
     FLASHINFER_WARN("Unsupported dtypeElt"); \
   }
 
-#define LAUNCH_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
-  if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, \
-               smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float), kernel, numBlocks, \
-               numThreads, smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float), kernel, numBlocks, numThreads, \
-               smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t), kernel, numBlocks, \
-               numThreads, smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t), kernel, \
-               numBlocks, numThreads, smemSize, stream); \
-  } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t), kernel, \
-               numBlocks, numThreads, smemSize, stream); \
-  } else { \
-    FLASHINFER_WARN("Unsupported pair"); \
+#define LAUNCH_EXPW(data, kernel, topK, numBlocks, numThreads, smemSize, stream) \
+  if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float, topK), kernel, numBlocks, \
+               numThreads, smemSize, stream); \
+  } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float, topK), kernel, numBlocks, \
+               numThreads, smemSize, stream); \
+  } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float, topK), kernel, numBlocks, \
+               numThreads, smemSize, stream); \
+  } else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t, topK), kernel, \
+               numBlocks, numThreads, smemSize, stream); \
+  } else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t, topK), kernel, \
+               numBlocks, numThreads, smemSize, stream); \
+  } else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
+    LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, topK), kernel, \
+               numBlocks, numThreads, smemSize, stream); \
+  } else { \
+    FLASHINFER_WARN("Unsupported pair"); \
+  }
+
+#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
+  if (data.topK % 4 == 0) { \
+    LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream); \
+  } else if (data.topK % 2 == 0) { \
+    LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream); \
+  } else { \
+    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream); \
   }
 
 #define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
```
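`LAUNCH_TOPK_EXPW` turns the runtime `data.topK` into a compile-time `TopKUnrollFactor` by choosing the largest of 4/2/1 that divides it, then forwards it through `LAUNCH_EXPW`/`LAUNCH_PDL` as an extra template argument. A standalone sketch of that runtime-to-compile-time dispatch without the macros (`launchFinalize` and `launchForTopK` are illustrative stand-ins, not functions in the diff):

```cuda
// Sketch of the dispatch done by LAUNCH_TOPK_EXPW: pick the largest unroll
// factor that divides topK at runtime and pass it as a template parameter.
#include <cstdio>

template <int TopKUnrollFactor>
void launchFinalize(int topK) {
  // Stand-in for the real kernel launch that receives TopKUnrollFactor
  // through KernelParams.
  printf("topK=%d launched with TopKUnrollFactor=%d\n", topK, TopKUnrollFactor);
}

void launchForTopK(int topK) {
  if (topK % 4 == 0) {
    launchFinalize<4>(topK);
  } else if (topK % 2 == 0) {
    launchFinalize<2>(topK);
  } else {
    launchFinalize<1>(topK);
  }
}

int main() {
  launchForTopK(8);  // -> unroll factor 4
  launchForTopK(6);  // -> unroll factor 2
  launchForTopK(5);  // -> unroll factor 1
  return 0;
}
```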
```diff
@@ -453,10 +462,11 @@ struct Data {
   int32_t const* totalNumPaddedTokens;
 };
 
-template <typename Type_, typename TypeExpW_, bool UsePdl_>
+template <typename Type_, typename TypeExpW_, int TopKUnrollFactor_, bool UsePdl_>
 struct KernelParams {
   using Type = Type_;
   using TypeExpW = TypeExpW_;
+  static constexpr int TopKUnrollFactor = TopKUnrollFactor_;
   static constexpr bool UsePdl = UsePdl_;
 
   Type const* inPtr;
```