@@ -126,77 +126,55 @@ namespace moe::dev {
126126 }
127127
128128#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
129- smemSize, stream, extraFlag) \
129+ smemSize, stream, extraFlag, numExperts) \
130130 if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
131131 data.mDtypeExpW == tg::Dtype::Fp32) { \
132- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , float , float , extraFlag), kernel, numBlocks, \
133- numThreads, smemSize, stream); \
132+ LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , float , float , numExperts, extraFlag), kernel, \
133+ numBlocks, numThreads, smemSize, stream); \
134134 } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
135135 data.mDtypeExpW == tg::Dtype::Bfloat16) { \
136- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , float , __nv_bfloat16, extraFlag), kernel, \
137- numBlocks, numThreads, smemSize, stream); \
136+ LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , float , __nv_bfloat16, numExperts, extraFlag), \
137+ kernel, numBlocks, numThreads, smemSize, stream); \
138138 } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
139139 data.mDtypeExpW == tg::Dtype::Fp32) { \
140- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , __nv_bfloat16, float , extraFlag), kernel, \
141- numBlocks, numThreads, smemSize, stream); \
140+ LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , __nv_bfloat16, float , numExperts, extraFlag), \
141+ kernel, numBlocks, numThreads, smemSize, stream); \
142142 } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
143143 data.mDtypeExpW == tg::Dtype::Bfloat16) { \
144- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , __nv_bfloat16, __nv_bfloat16, extraFlag), \
145- kernel, numBlocks, numThreads, smemSize, stream); \
144+ LAUNCH_PDL (data, coopLaunch, \
145+ LAUNCH_ESC (float , __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \
146+ numBlocks, numThreads, smemSize, stream); \
146147 } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
147148 data.mDtypeExpW == tg::Dtype::Fp32) { \
148- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (__nv_bfloat16, float , float , extraFlag), kernel, \
149- numBlocks, numThreads, smemSize, stream); \
149+ LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (__nv_bfloat16, float , float , numExperts, extraFlag), \
150+ kernel, numBlocks, numThreads, smemSize, stream); \
150151 } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
151152 data.mDtypeExpW == tg::Dtype::Bfloat16) { \
152- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (__nv_bfloat16, float , __nv_bfloat16, extraFlag), \
153- kernel, numBlocks, numThreads, smemSize, stream); \
153+ LAUNCH_PDL (data, coopLaunch, \
154+ LAUNCH_ESC (__nv_bfloat16, float , __nv_bfloat16, numExperts, extraFlag), kernel, \
155+ numBlocks, numThreads, smemSize, stream); \
154156 } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
155157 data.mDtypeExpW == tg::Dtype::Fp32) { \
156- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, float , extraFlag), \
157- kernel, numBlocks, numThreads, smemSize, stream); \
158+ LAUNCH_PDL (data, coopLaunch, \
159+ LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, float , numExperts, extraFlag), kernel, \
160+ numBlocks, numThreads, smemSize, stream); \
158161 } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
159162 data.mDtypeExpW == tg::Dtype::Bfloat16) { \
160163 LAUNCH_PDL (data, coopLaunch, \
161- LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, extraFlag), kernel, \
162- numBlocks, numThreads, smemSize, stream); \
164+ LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \
165+ kernel, numBlocks, numThreads, smemSize, stream); \
163166 } else { \
164167 FLASHINFER_WARN (" Unsupported dtypeExpW" ); \
165168 }
166169
167- #define LAUNCH_ROUTING_DEEPSEEK (data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \
168- extraFlag) \
169- if (extraFlag) { \
170- LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
171- smemSize, stream, true ); \
172- } else { \
173- LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
174- smemSize, stream, false ); \
175- }
176-
177- #define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT (data, coopLaunch, kernel, numBlocks, \
178- numThreads, smemSize, stream, extraFlag, \
179- forceFloatInput, numExperts) \
180- if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) { \
181- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , float , numExperts, true ), kernel, numBlocks, \
182- numThreads, smemSize, stream); \
183- } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \
184- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , float , numExperts, false ), kernel, numBlocks, \
185- numThreads, smemSize, stream); \
186- } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) { \
187- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , __nv_bfloat16, numExperts, true ), kernel, \
188- numBlocks, numThreads, smemSize, stream); \
189- } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) { \
190- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, numExperts, true ), \
191- kernel, numBlocks, numThreads, smemSize, stream); \
192- } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) { \
193- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (float , __nv_bfloat16, numExperts, false ), kernel, \
194- numBlocks, numThreads, smemSize, stream); \
195- } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
196- LAUNCH_PDL (data, coopLaunch, LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, numExperts, false ), \
197- kernel, numBlocks, numThreads, smemSize, stream); \
198- } else { \
199- FLASHINFER_WARN (" Unsupported dtypeExpW" ); \
170+ #define LAUNCH_ROUTING_DEEPSEEK_IMPL (data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
171+ stream, extraFlag, numExperts) \
172+ if (extraFlag) { \
173+ LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
174+ smemSize, stream, true , numExperts); \
175+ } else { \
176+ LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
177+ smemSize, stream, false , numExperts); \
200178 }
201179
202180// //////////////////////////////////////////////////////////////////////////////////////////////////
0 commit comments