Skip to content

Commit 721371d

Browse files
committed
Revise the code according to comment
1 parent 1b8587c commit 721371d

File tree

4 files changed

+16
-13
lines changed

4 files changed

+16
-13
lines changed

csrc/trtllm_fused_moe_routing_deepseek.cu

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,19 @@ __global__ void routingMainKernel(KernelParams params) {
189189
__syncthreads();
190190
if (warpIdx == 0) {
191191
int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1;
192-
float intermidiateScore[NumInterTopKPerThread];
193-
int32_t intermidiateExpert[NumInterTopKPerThread];
192+
float intermediateScore[NumInterTopKPerThread];
193+
int32_t intermediateExpert[NumInterTopKPerThread];
194194
for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) {
195195
int ii = i / WarpSize;
196196
if (i < NumInterTopK) {
197-
intermidiateScore[ii] = smemInterTopScores[i];
198-
intermidiateExpert[ii] = smemInterTopExperts[i];
197+
intermediateScore[ii] = smemInterTopScores[i];
198+
intermediateExpert[ii] = smemInterTopExperts[i];
199199
} else {
200-
intermidiateScore[ii] = invalidScoreFloat;
201-
intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1;
200+
intermediateScore[ii] = invalidScoreFloat;
201+
intermediateExpert[ii] = KernelParams::MaxNumExperts - 1;
202202
}
203203
}
204-
topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert,
204+
topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert,
205205
/* minValue */ invalidScoreFloat, params.mTopK);
206206
}
207207
} else {

include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -419,12 +419,13 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts)
419419
PackedScoreIdx<OutputT> scoreIdx;
420420
int idx;
421421
if (params.mPtrTopKIds != nullptr) {
422+
// If params.mPtrTopKIds != nullptr, we don't need to store the weights
422423
idx = params.mPtrTopKIds[expandedIdx];
423424
} else {
424-
// If params.mPtrTopKIds != nullptr, we don't need to store the weights
425+
scoreIdx = params.mPtrTopKPacked[expandedIdx];
426+
idx = scoreIdx.idx;
425427
if (params.mPtrTopKWeights != nullptr) {
426-
scoreIdx = params.mPtrTopKPacked[expandedIdx];
427-
idx = scoreIdx.idx;
428+
// For now, params.mPtrTopKWeights shouldn't be nullptr.
428429
params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
429430
}
430431
}

include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ struct TopKRedType {
5353
static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0) {
5454
auto valueBits = cub::Traits<TypeExpW>::TwiddleIn(
5555
reinterpret_cast<typename cub::Traits<TypeExpW>::UnsignedBits&>(val));
56-
TypeCmp compactTmp;
57-
memcpy(&compactTmp, &valueBits, sizeof(valueBits));
56+
TypeCmp compactTmp = valueBits;
5857
compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx));
5958
// Use 65535 minus idx to give higher priority to elements with smaller indices.
6059
return compactTmp;
@@ -203,6 +202,9 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const
203202
static_assert(K < WarpSize, "Top K must have K < WarpSize");
204203
static_assert(N > 0, "Top K must have N > 0");
205204
static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512");
205+
static_assert(
206+
N <= 4 || N % 4 == 0,
207+
"Only support candidates number is a multiple of 4*32=128 or less than or equal to 4");
206208
using RedType = TopKRedType<Type>;
207209

208210
if constexpr (N <= 4) {

include/flashinfer/trtllm/fused_moe/runner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ struct MoERunnerArgs {
265265
// Hidden dimension output of MoE block. It is not padded.
266266
// If not provided it is the same as hidden_size.
267267
std::optional<int32_t> hidden_size_output;
268-
// TODO: only compiled routing kernel supports top_k = 8
268+
// Now support top_k<=10
269269
int32_t top_k{0};
270270
int32_t n_group{0};
271271
// TODO: only compiled routing kernel supports topk_group = 4

0 commit comments

Comments
 (0)