@@ -672,11 +672,128 @@ __device__ float4 vectorizedLoadPtx(float4 const* ptr) {
 // Final kernel to unpermute and scale
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
 // connection.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+constexpr int MaxTopK = 64;
+
+typedef struct __CUDA_ALIGN__(4) {
+  cutlass::bfloat16_t array[2];
+} bfloat16_2;
+
+typedef struct __CUDA_ALIGN__(8) {
+  cutlass::bfloat16_t array[4];
+} bfloat16_4;
+
+typedef struct __CUDA_ALIGN__(8) {
+  half array[4];
+} half_4;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int UnrollFactor_, typename TypeExpW_>
+struct ScaleTraitsStruct;
+
+template <>
+struct ScaleTraitsStruct<1, cutlass::bfloat16_t> {
+  using PackedType = cutlass::bfloat16_t;
+  using ArrayType = cutlass::Array<cutlass::bfloat16_t, 1>;
+};
+
+template <>
+struct ScaleTraitsStruct<2, cutlass::bfloat16_t> {
+  using PackedType = bfloat16_2;
+  using ArrayType = cutlass::Array<cutlass::bfloat16_t, 2>;
+};
+
+template <>
+struct ScaleTraitsStruct<4, cutlass::bfloat16_t> {
+  using PackedType = bfloat16_4;
+  using ArrayType = cutlass::Array<cutlass::bfloat16_t, 4>;
+};
+
+template <>
+struct ScaleTraitsStruct<1, float> {
+  using PackedType = float;
+  using ArrayType = cutlass::Array<float, 1>;
+};
+
+template <>
+struct ScaleTraitsStruct<2, float> {
+  using PackedType = float2;
+  using ArrayType = cutlass::Array<float, 2>;
+};
+
+template <>
+struct ScaleTraitsStruct<4, float> {
+  using PackedType = float4;
+  using ArrayType = cutlass::Array<float, 4>;
+};
+
+template <>
+struct ScaleTraitsStruct<1, half> {
+  using PackedType = half;
+  using ArrayType = cutlass::Array<half, 1>;
+};
+
+template <>
+struct ScaleTraitsStruct<2, half> {
+  using PackedType = half2;
+  using ArrayType = cutlass::Array<half, 2>;
+};
+
+template <>
+struct ScaleTraitsStruct<4, half> {
+  using PackedType = half_4;
+  using ArrayType = cutlass::Array<half, 4>;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int UnrollFactor_, typename TypeExpW_>
+struct FinalizeTraits;
+
+template <typename TypeExpW_>
+struct FinalizeTraits<1, TypeExpW_> {
+  using IdxPackedType = int;
+  using IdxArrayType = cutlass::Array<int, 1>;
+  using ScaleTraits = ScaleTraitsStruct<1, TypeExpW_>;
+  using ScalePackedType = typename ScaleTraits::PackedType;
+  using ScaleArrayType = typename ScaleTraits::ArrayType;
+};
+
+template <typename TypeExpW_>
+struct FinalizeTraits<2, TypeExpW_> {
+  using IdxPackedType = int2;
+  using IdxArrayType = cutlass::Array<int, 2>;
+  using ScaleTraits = ScaleTraitsStruct<2, TypeExpW_>;
+  using ScalePackedType = typename ScaleTraits::PackedType;
+  using ScaleArrayType = typename ScaleTraits::ArrayType;
+};
+
+template <typename TypeExpW_>
+struct FinalizeTraits<4, TypeExpW_> {
+  using IdxPackedType = int4;
+  using IdxArrayType = cutlass::Array<int, 4>;
+  using ScaleTraits = ScaleTraitsStruct<4, TypeExpW_>;
+  using ScalePackedType = typename ScaleTraits::PackedType;
+  using ScaleArrayType = typename ScaleTraits::ArrayType;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename KernelParams>
 __global__ void finalizeKernelVecLoad(KernelParams params) {
   using Type = typename KernelParams::Type;
   using TypeExpW = typename KernelParams::TypeExpW;
+  int constexpr TopKUnrollFactor = KernelParams::TopKUnrollFactor;
+
+  static_assert(TopKUnrollFactor == 1 || TopKUnrollFactor == 2 || TopKUnrollFactor == 4,
+                "TopKUnrollFactor must be 1, 2, or 4");
+  using FinalizeTraits = FinalizeTraits<TopKUnrollFactor, TypeExpW>;
+  using IdxPackedType = typename FinalizeTraits::IdxPackedType;
+  using IdxArrayType = typename FinalizeTraits::IdxArrayType;
+  using ScalePackedType = typename FinalizeTraits::ScalePackedType;
+  using ScaleArrayType = typename FinalizeTraits::ScaleArrayType;

   int const hiddenDimPaddedBits = params.hiddenDimPadded * cutlass::sizeof_bits<Type>::value;
   int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits<Type>::value;
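The `ScaleTraitsStruct`/`FinalizeTraits` specializations added in this hunk pair each unroll factor with a packed, aligned POD type (so that `TopKUnrollFactor` indices or expert scales can be fetched with a single vector load) and a `cutlass::Array` view of the same bytes (for per-element access). A minimal sketch of that access pattern, assuming the `FinalizeTraits` defined above and a `float` expert-weight type; the helper name is illustrative, not part of the commit:

```cpp
// Illustrative only: one aligned vector load of UnrollFactor expert weights,
// then per-lane access through the matching cutlass::Array view.
template <int UnrollFactor>
__device__ float sumExpertWeights(float const* expertWeightsPtr, int chunkIdx) {
  using Traits = FinalizeTraits<UnrollFactor, float>;
  // e.g. a single float4 load when UnrollFactor == 4
  auto packed =
      reinterpret_cast<typename Traits::ScalePackedType const*>(expertWeightsPtr)[chunkIdx];
  auto arr = *reinterpret_cast<typename Traits::ScaleArrayType const*>(&packed);
  float acc = 0.f;
#pragma unroll
  for (int i = 0; i < UnrollFactor; ++i) {
    acc += arr[i];
  }
  return acc;
}
```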
@@ -694,6 +811,23 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
   int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
   int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD;
   int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD;
+  bool const useScale = params.expertWeightsPtr != nullptr;
+
+  __shared__ ScalePackedType scaleArrSmem[MaxTopK / TopKUnrollFactor];
+  __shared__ IdxPackedType permutedIdxArrSmem[MaxTopK / TopKUnrollFactor];
+
+  for (int kChunkIdx = threadIdx.x; kChunkIdx < params.topK / TopKUnrollFactor;
+       kChunkIdx += blockDim.x) {
+    int const expandedIdx = tokenIdx * params.topK + kChunkIdx * TopKUnrollFactor;
+    auto permutedIdxPacked = reinterpret_cast<IdxPackedType const*>(
+        params.expandedIdxToPermutedIdx)[expandedIdx / TopKUnrollFactor];
+    auto scalePacked = useScale ? reinterpret_cast<ScalePackedType const*>(
+                                      params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
+                                : ScalePackedType{TypeExpW(1.f)};
+
+    scaleArrSmem[kChunkIdx] = scalePacked;
+    permutedIdxArrSmem[kChunkIdx] = permutedIdxPacked;
+  }

   auto const offset = tokenIdx * params.hiddenDim;
   Type* outputPtr = params.outPtr + offset;
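The staging loop added in this hunk copies the token's permuted indices and expert scales into shared memory in packed form: the block's threads cooperatively load `topK / TopKUnrollFactor` packed entries once, and every thread then rereads them on each iteration of the hidden-dimension loop instead of re-issuing the same global loads. A generic sketch of that block-stride staging idiom, with a hypothetical payload type `T` and count `n` (not taken verbatim from the commit); in the kernel itself the matching `__syncthreads()` is issued just before the main loop in the next hunk:

```cpp
// Sketch of the block-stride copy-to-shared idiom used above.
template <typename T>
__device__ void stageToShared(T* smem, T const* gmem, int n) {
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    smem[i] = gmem[i];  // each thread copies a strided subset of the entries
  }
  __syncthreads();  // make all staged entries visible to every thread in the block
}
```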
@@ -706,31 +840,42 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
     cudaGridDependencySynchronize();
   }
 #endif
+  __syncthreads();

   for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride) {
     ComputeElem threadOutput;
     threadOutput.fill(0);
-    for (int k = 0; k < params.topK; ++k) {
-      int const expandedIdx = tokenIdx * params.topK + k;
-      int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
-      if (permutedIdx == -1) {
-        continue;
-      }
-
-      float const scale = (params.expertWeightsPtr != nullptr)
-                              ? static_cast<float>(params.expertWeightsPtr[expandedIdx])
-                              : 1.f;
+    for (int kChunkIdx = 0; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx++) {
+      auto permutedIdxArr = *reinterpret_cast<IdxArrayType const*>(&permutedIdxArrSmem[kChunkIdx]);
+      InputElem inputElemArr[TopKUnrollFactor];
+#pragma unroll
+      for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
+        auto const permutedIdx = permutedIdxArr[ki];
+        if (permutedIdx == -1) {
+          continue;
+        }

-      auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;
+        auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;

-      float4 input =
-          vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
-      InputElem inputPermutedElem = *reinterpret_cast<InputElem const*>(&input);
-      ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputPermutedElem);
+        float4 input =
+            vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
+        inputElemArr[ki] = *reinterpret_cast<InputElem const*>(&input);
+      }
+      auto scaleArr = *reinterpret_cast<ScaleArrayType const*>(&scaleArrSmem[kChunkIdx]);
+      auto const scaleFloatArr =
+          arrayConvert<ScaleArrayType, cutlass::Array<float, TopKUnrollFactor>>(scaleArr);

-      threadOutput = threadOutput + scale * expertResult;
+#pragma unroll
+      for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
+        auto const permutedIdx = permutedIdxArr[ki];
+        if (permutedIdx == -1) {
+          continue;
+        }
+        auto scale = useScale ? scaleFloatArr[ki] : 1.0f;
+        ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputElemArr[ki]);
+        threadOutput = threadOutput + scale * expertResult;
+      }
     }
-
     OutputElem outputElem = arrayConvert<ComputeElem, OutputElem>(threadOutput);
     outElemPtr[elemIndex] = outputElem;
   }
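Per output element, the reduction computed here is the same as in the removed code: a k-way weighted sum over the token's permuted expert rows, skipping entries whose permuted index is -1, with the scale defaulting to 1 when no expert weights are given. A scalar (non-vectorized) sketch of that computation, using parameter names from the diff but not the actual `KernelParams` layout:

```cpp
// Scalar reference of the finalize reduction for one (token, hidden-dim) element:
//   out[token][h] = sum_k scale[token][k] * in[permutedIdx(token, k)][h]
__device__ float finalizeOneElement(float const* in, float const* expertWeights,
                                    int const* expandedIdxToPermutedIdx, int tokenIdx, int topK,
                                    int hiddenDimPadded, int h) {
  float acc = 0.f;
  for (int k = 0; k < topK; ++k) {
    int const expandedIdx = tokenIdx * topK + k;
    int const permutedIdx = expandedIdxToPermutedIdx[expandedIdx];
    if (permutedIdx == -1) {
      continue;  // this expert slot was dropped during permutation
    }
    float const scale = (expertWeights != nullptr) ? expertWeights[expandedIdx] : 1.f;
    acc += scale * in[permutedIdx * hiddenDimPadded + h];
  }
  return acc;
}
```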
@@ -813,7 +958,7 @@ void run(Data const& data, void* stream) {
     int const numBlocksY = std::min(8192, data.numTokens);
     dim3 numBlocks(numBlocksX, numBlocksY);

-    LAUNCH_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream);
+    LAUNCH_TOPK_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream);
   } else {
     int const numThreads = 256;
     int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads;
@@ -827,10 +972,14 @@ void run(Data const& data, void* stream) {
       // ensure that when the number of waves is greater than 1, we choose to use the kernel with
       // vectorized loading.
       dim3 numBlocks(numBlocksX, numBlocksY);
-      LAUNCH_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
+      LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
     } else {
-      LAUNCH_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
-                  /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
+      FLASHINFER_CHECK(
+          data.topK <= MaxTopK,
+          "Finalize kernel with vectorized loading is not supported for this TopK value: %d",
+          data.topK);
+      LAUNCH_TOPK_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
+                       /*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
     }
   }
 }
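`LAUNCH_TOPK_EXPW` (its definition is not shown in this diff) additionally dispatches on `data.topK` to pick `KernelParams::TopKUnrollFactor`. Since the kernel iterates over `topK / TopKUnrollFactor` packed chunks and stages at most `MaxTopK` entries in shared memory, the factor has to divide `topK` evenly and `topK` must not exceed `MaxTopK`, which is what the `FLASHINFER_CHECK` above guards. A hypothetical selection rule consistent with those constraints (the real dispatch lives in the macro and may differ):

```cpp
// Hypothetical helper: pick the largest supported unroll factor that divides topK.
inline int selectTopKUnrollFactor(int topK) {
  if (topK % 4 == 0) return 4;
  if (topK % 2 == 0) return 2;
  return 1;
}
```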