Skip to content

Commit 63233a3

Browse files
committed
fix merge conflict
Signed-off-by: jiahanc <[email protected]>
1 parent 5cfa98b commit 63233a3

File tree

3 files changed

+37
-61
lines changed

3 files changed

+37
-61
lines changed

csrc/trtllm_fused_moe_routing_deepseek.cu

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -476,19 +476,16 @@ int constexpr getMaxNumExperts(int32_t numExperts) {
476476

477477
////////////////////////////////////////////////////////////////////////////////////////////////////
478478
#define LAUNCH_ROUTING_DEEPSEEK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \
479-
extraFlag1, forceFloatInput) \
479+
extraFlag) \
480480
if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \
481-
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \
482-
numThreads, smemSize, stream, extraFlag1, \
483-
forceFloatInput, topk::MaxNumExpertsUnit); \
481+
LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
482+
stream, extraFlag, topk::MaxNumExpertsUnit); \
484483
} else if (data.mNumExperts <= NumDeepseekExperts) { \
485-
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \
486-
numThreads, smemSize, stream, extraFlag1, \
487-
forceFloatInput, NumDeepseekExperts); \
484+
LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
485+
stream, extraFlag, NumDeepseekExperts); \
488486
} else if (data.mNumExperts <= NumKimiK2Experts) { \
489-
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \
490-
numThreads, smemSize, stream, extraFlag1, \
491-
forceFloatInput, NumKimiK2Experts); \
487+
LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
488+
stream, extraFlag, NumKimiK2Experts); \
492489
} else { \
493490
TLLM_LOG_ERROR("Unsupported numExperts"); \
494491
}

include/flashinfer/trtllm/fused_moe/DevKernel.h

Lines changed: 28 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -126,77 +126,55 @@ namespace moe::dev {
126126
}
127127

128128
#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
129-
smemSize, stream, extraFlag) \
129+
smemSize, stream, extraFlag, numExperts) \
130130
if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
131131
data.mDtypeExpW == tg::Dtype::Fp32) { \
132-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, float, extraFlag), kernel, numBlocks, \
133-
numThreads, smemSize, stream); \
132+
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, float, numExperts, extraFlag), kernel, \
133+
numBlocks, numThreads, smemSize, stream); \
134134
} else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
135135
data.mDtypeExpW == tg::Dtype::Bfloat16) { \
136-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, extraFlag), kernel, \
137-
numBlocks, numThreads, smemSize, stream); \
136+
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, extraFlag), \
137+
kernel, numBlocks, numThreads, smemSize, stream); \
138138
} else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
139139
data.mDtypeExpW == tg::Dtype::Fp32) { \
140-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, extraFlag), kernel, \
141-
numBlocks, numThreads, smemSize, stream); \
140+
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, extraFlag), \
141+
kernel, numBlocks, numThreads, smemSize, stream); \
142142
} else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
143143
data.mDtypeExpW == tg::Dtype::Bfloat16) { \
144-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, extraFlag), \
145-
kernel, numBlocks, numThreads, smemSize, stream); \
144+
LAUNCH_PDL(data, coopLaunch, \
145+
LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \
146+
numBlocks, numThreads, smemSize, stream); \
146147
} else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
147148
data.mDtypeExpW == tg::Dtype::Fp32) { \
148-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, extraFlag), kernel, \
149-
numBlocks, numThreads, smemSize, stream); \
149+
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, extraFlag), \
150+
kernel, numBlocks, numThreads, smemSize, stream); \
150151
} else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
151152
data.mDtypeExpW == tg::Dtype::Bfloat16) { \
152-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, extraFlag), \
153-
kernel, numBlocks, numThreads, smemSize, stream); \
153+
LAUNCH_PDL(data, coopLaunch, \
154+
LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, extraFlag), kernel, \
155+
numBlocks, numThreads, smemSize, stream); \
154156
} else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
155157
data.mDtypeExpW == tg::Dtype::Fp32) { \
156-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, extraFlag), \
157-
kernel, numBlocks, numThreads, smemSize, stream); \
158+
LAUNCH_PDL(data, coopLaunch, \
159+
LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, extraFlag), kernel, \
160+
numBlocks, numThreads, smemSize, stream); \
158161
} else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
159162
data.mDtypeExpW == tg::Dtype::Bfloat16) { \
160163
LAUNCH_PDL(data, coopLaunch, \
161-
LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, extraFlag), kernel, \
162-
numBlocks, numThreads, smemSize, stream); \
164+
LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \
165+
kernel, numBlocks, numThreads, smemSize, stream); \
163166
} else { \
164167
FLASHINFER_WARN("Unsupported dtypeExpW"); \
165168
}
166169

167-
#define LAUNCH_ROUTING_DEEPSEEK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \
168-
extraFlag) \
169-
if (extraFlag) { \
170-
LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
171-
smemSize, stream, true); \
172-
} else { \
173-
LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
174-
smemSize, stream, false); \
175-
}
176-
177-
#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \
178-
numThreads, smemSize, stream, extraFlag, \
179-
forceFloatInput, numExperts) \
180-
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) { \
181-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \
182-
numThreads, smemSize, stream); \
183-
} else if (data.mDtypeExpW == tg::Dtype::Fp32) { \
184-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \
185-
numThreads, smemSize, stream); \
186-
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) { \
187-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, true), kernel, \
188-
numBlocks, numThreads, smemSize, stream); \
189-
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) { \
190-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \
191-
kernel, numBlocks, numThreads, smemSize, stream); \
192-
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) { \
193-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, false), kernel, \
194-
numBlocks, numThreads, smemSize, stream); \
195-
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
196-
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \
197-
kernel, numBlocks, numThreads, smemSize, stream); \
198-
} else { \
199-
FLASHINFER_WARN("Unsupported dtypeExpW"); \
170+
#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
171+
stream, extraFlag, numExperts) \
172+
if (extraFlag) { \
173+
LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
174+
smemSize, stream, true, numExperts); \
175+
} else { \
176+
LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
177+
smemSize, stream, false, numExperts); \
200178
}
201179

202180
////////////////////////////////////////////////////////////////////////////////////////////////////

include/flashinfer/trtllm/fused_moe/RoutingKernel.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ struct Data : public DataBase {
172172
bool mUseRoutingSoftmax;
173173
};
174174

175-
template <typename InputT_, typename BiasT_, typename OutputT_, int MaxNumExperts_, bool UseGroups_, bool UsePdl_>
175+
template <typename InputT_, typename BiasT_, typename OutputT_, int MaxNumExperts_, bool UseGroups_,
176+
bool UsePdl_>
176177
struct KernelParams : public KernelParamsBase<InputT_, OutputT_, MaxNumExperts_, UsePdl_> {
177178
using InputT = InputT_;
178179
using BiasT = BiasT_;

0 commit comments

Comments
 (0)