Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 170 additions & 21 deletions csrc/trtllm_fused_moe_dev_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -672,11 +672,128 @@ __device__ float4 vectorizedLoadPtx(float4 const* ptr) {
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip
// connection.
////////////////////////////////////////////////////////////////////////////////////////////////////

// Upper bound on topK supported by the vectorized finalize kernel; sizes the
// per-block shared-memory staging arrays for scales and permuted indices.
constexpr int MaxTopK = 64;

// Packed scale containers sized/aligned for a single vectorized load/store.
// Alignment matches the total payload size (2 x bf16 = 4 B, 4 x bf16 = 8 B,
// 4 x half = 8 B) so pointers into a contiguous scale buffer can be
// reinterpret_cast to these types.
struct __CUDA_ALIGN__(4) bfloat16_2 {
  cutlass::bfloat16_t array[2];
};

struct __CUDA_ALIGN__(8) bfloat16_4 {
  cutlass::bfloat16_t array[4];
};

struct __CUDA_ALIGN__(8) half_4 {
  half array[4];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace scale_traits_detail {

// Selector mapping (unroll factor N, scale element type T) to the packed
// POD type carrying N scale elements in one vectorized load.
template <int N, typename T>
struct PackedScale;

// An unroll factor of 1 packs to the scalar element type itself.
template <typename T>
struct PackedScale<1, T> {
  using type = T;
};

template <>
struct PackedScale<2, cutlass::bfloat16_t> {
  using type = bfloat16_2;
};

template <>
struct PackedScale<4, cutlass::bfloat16_t> {
  using type = bfloat16_4;
};

template <>
struct PackedScale<2, float> {
  using type = float2;
};

template <>
struct PackedScale<4, float> {
  using type = float4;
};

template <>
struct PackedScale<2, half> {
  using type = half2;
};

template <>
struct PackedScale<4, half> {
  using type = half_4;
};

}  // namespace scale_traits_detail

// Traits bundling, for a given unroll factor and scale element type:
//  - PackedType: POD vector-load representation of UnrollFactor_ scales, and
//  - ArrayType:  the matching cutlass::Array for element-wise access.
// Supported combinations are UnrollFactor_ in {1, 2, 4} with element type
// cutlass::bfloat16_t, float, or half; anything else fails to instantiate.
template <int UnrollFactor_, typename TypeExpW_>
struct ScaleTraitsStruct {
  using PackedType = typename scale_traits_detail::PackedScale<UnrollFactor_, TypeExpW_>::type;
  using ArrayType = cutlass::Array<TypeExpW_, UnrollFactor_>;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace finalize_traits_detail {

// Selector mapping an unroll factor N to the CUDA built-in vector type that
// carries N int indices in one vectorized load.
template <int N>
struct PackedIdx;

template <>
struct PackedIdx<1> {
  using type = int;
};

template <>
struct PackedIdx<2> {
  using type = int2;
};

template <>
struct PackedIdx<4> {
  using type = int4;
};

}  // namespace finalize_traits_detail

// Traits used by the vectorized finalize kernel for a given TopK unroll
// factor: packed/array types for the permuted-index loads plus the scale
// types from ScaleTraitsStruct. Only unroll factors 1, 2, and 4 instantiate.
template <int UnrollFactor_, typename TypeExpW_>
struct FinalizeTraits {
  using IdxPackedType = typename finalize_traits_detail::PackedIdx<UnrollFactor_>::type;
  using IdxArrayType = cutlass::Array<int, UnrollFactor_>;
  using ScaleTraits = ScaleTraitsStruct<UnrollFactor_, TypeExpW_>;
  using ScalePackedType = typename ScaleTraits::PackedType;
  using ScaleArrayType = typename ScaleTraits::ArrayType;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
__global__ void finalizeKernelVecLoad(KernelParams params) {
using Type = typename KernelParams::Type;
using TypeExpW = typename KernelParams::TypeExpW;
int constexpr TopKUnrollFactor = KernelParams::TopKUnrollFactor;

static_assert(TopKUnrollFactor == 1 || TopKUnrollFactor == 2 || TopKUnrollFactor == 4,
"TopKUnrollFactor must be 1, 2, or 4");
using FinalizeTraits = FinalizeTraits<TopKUnrollFactor, TypeExpW>;
using IdxPackedType = typename FinalizeTraits::IdxPackedType;
using IdxArrayType = typename FinalizeTraits::IdxArrayType;
using ScalePackedType = typename FinalizeTraits::ScalePackedType;
using ScaleArrayType = typename FinalizeTraits::ScaleArrayType;

int const hiddenDimPaddedBits = params.hiddenDimPadded * cutlass::sizeof_bits<Type>::value;
int const hiddenDimBits = params.hiddenDim * cutlass::sizeof_bits<Type>::value;
Expand All @@ -694,6 +811,23 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
int64_t const numElemsInPaddedCol = params.hiddenDimPadded / FINALIZE_ELEM_PER_THREAD;
int64_t const numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD;
bool const useScale = params.expertWeightsPtr != nullptr;

__shared__ ScalePackedType scaleArrSmem[MaxTopK / TopKUnrollFactor];
__shared__ IdxPackedType permutedIdxArrSmem[MaxTopK / TopKUnrollFactor];

for (int kChunkIdx = threadIdx.x; kChunkIdx < params.topK / TopKUnrollFactor;
kChunkIdx += blockDim.x) {
int const expandedIdx = tokenIdx * params.topK + kChunkIdx * TopKUnrollFactor;
auto permutedIdxPacked = reinterpret_cast<IdxPackedType const*>(
params.expandedIdxToPermutedIdx)[expandedIdx / TopKUnrollFactor];
auto scalePacked = useScale ? reinterpret_cast<ScalePackedType const*>(
params.expertWeightsPtr)[expandedIdx / TopKUnrollFactor]
: ScalePackedType{TypeExpW(1.f)};

scaleArrSmem[kChunkIdx] = scalePacked;
permutedIdxArrSmem[kChunkIdx] = permutedIdxPacked;
}

auto const offset = tokenIdx * params.hiddenDim;
Type* outputPtr = params.outPtr + offset;
Expand All @@ -706,31 +840,42 @@ __global__ void finalizeKernelVecLoad(KernelParams params) {
cudaGridDependencySynchronize();
}
#endif
__syncthreads();

for (int elemIndex = startOffset; elemIndex < numElemsInCol; elemIndex += stride) {
ComputeElem threadOutput;
threadOutput.fill(0);
for (int k = 0; k < params.topK; ++k) {
int const expandedIdx = tokenIdx * params.topK + k;
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
if (permutedIdx == -1) {
continue;
}

float const scale = (params.expertWeightsPtr != nullptr)
? static_cast<float>(params.expertWeightsPtr[expandedIdx])
: 1.f;
for (int kChunkIdx = 0; kChunkIdx < params.topK / TopKUnrollFactor; kChunkIdx++) {
auto permutedIdxArr = *reinterpret_cast<IdxArrayType const*>(&permutedIdxArrSmem[kChunkIdx]);
InputElem inputElemArr[TopKUnrollFactor];
#pragma unroll
for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
auto const permutedIdx = permutedIdxArr[ki];
if (permutedIdx == -1) {
continue;
}

auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;
auto const* inputPermutedPtr = inElemPtr + permutedIdx * numElemsInPaddedCol;

float4 input =
vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
InputElem inputPermutedElem = *reinterpret_cast<InputElem const*>(&input);
ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputPermutedElem);
float4 input =
vectorizedLoadPtx(reinterpret_cast<float4 const*>(&inputPermutedPtr[elemIndex]));
inputElemArr[ki] = *reinterpret_cast<InputElem const*>(&input);
}
auto scaleArr = *reinterpret_cast<ScaleArrayType const*>(&scaleArrSmem[kChunkIdx]);
auto const scaleFloatArr =
arrayConvert<ScaleArrayType, cutlass::Array<float, TopKUnrollFactor>>(scaleArr);

threadOutput = threadOutput + scale * expertResult;
#pragma unroll
for (int ki = 0; ki < TopKUnrollFactor; ++ki) {
auto const permutedIdx = permutedIdxArr[ki];
if (permutedIdx == -1) {
continue;
}
auto scale = useScale ? scaleFloatArr[ki] : 1.0f;
ComputeElem expertResult = arrayConvert<InputElem, ComputeElem>(inputElemArr[ki]);
threadOutput = threadOutput + scale * expertResult;
}
}

OutputElem outputElem = arrayConvert<ComputeElem, OutputElem>(threadOutput);
outElemPtr[elemIndex] = outputElem;
}
Expand Down Expand Up @@ -813,7 +958,7 @@ void run(Data const& data, void* stream) {
int const numBlocksY = std::min(8192, data.numTokens);
dim3 numBlocks(numBlocksX, numBlocksY);

LAUNCH_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream);
LAUNCH_TOPK_EXPW(data, finalizeDeepSeekKernel, numBlocks, numThreads, 0, stream);
} else {
int const numThreads = 256;
int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads;
Expand All @@ -827,10 +972,14 @@ void run(Data const& data, void* stream) {
// ensure that when the number of waves is greater than 1, we choose to use the kernel with
// vectorized loading.
dim3 numBlocks(numBlocksX, numBlocksY);
LAUNCH_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
LAUNCH_TOPK_EXPW(data, finalizeKernel, numBlocks, numThreads, 0, stream);
} else {
LAUNCH_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
/*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
FLASHINFER_CHECK(
data.topK <= MaxTopK,
"Finalize kernel with vectorized loading is not supported for this TopK value: %d",
data.topK);
LAUNCH_TOPK_EXPW(data, finalizeKernelVecLoad, /*numBlocks=*/data.numTokens,
/*numThreads=*/FINALIZE_THREADS_PER_BLOCK, 0, stream);
}
}
}
Expand Down
54 changes: 32 additions & 22 deletions include/flashinfer/trtllm/fused_moe/DevKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,27 +116,36 @@ namespace moe::dev {
FLASHINFER_WARN("Unsupported dtypeElt"); \
}

#define LAUNCH_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, \
smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float), kernel, numBlocks, numThreads, \
smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else { \
FLASHINFER_WARN("Unsupported pair"); \
#define LAUNCH_EXPW(data, kernel, topK, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float, topK), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float, topK), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, float, topK), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t, topK), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t, topK), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, topK), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else { \
FLASHINFER_WARN("Unsupported pair"); \
}

// Dispatches LAUNCH_EXPW with the largest TopK unroll factor (4, 2, or 1)
// that evenly divides data.topK, so the kernel can load indices/scales with
// the widest vectorized access the runtime topK permits.
#define LAUNCH_TOPK_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
  if (data.topK % 4 == 0) {                                                     \
    LAUNCH_EXPW(data, kernel, 4, numBlocks, numThreads, smemSize, stream);      \
  } else if (data.topK % 2 == 0) {                                              \
    LAUNCH_EXPW(data, kernel, 2, numBlocks, numThreads, smemSize, stream);      \
  } else {                                                                      \
    LAUNCH_EXPW(data, kernel, 1, numBlocks, numThreads, smemSize, stream);      \
  }

#define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
Expand Down Expand Up @@ -453,10 +462,11 @@ struct Data {
int32_t const* totalNumPaddedTokens;
};

template <typename Type_, typename TypeExpW_, bool UsePdl_>
template <typename Type_, typename TypeExpW_, int TopKUnrollFactor_, bool UsePdl_>
struct KernelParams {
using Type = Type_;
using TypeExpW = TypeExpW_;
static constexpr int TopKUnrollFactor = TopKUnrollFactor_;
static constexpr bool UsePdl = UsePdl_;

Type const* inPtr;
Expand Down