@@ -59,6 +59,7 @@ __device__ void moe_fused_gate_impl(
5959 int64_t topk,
6060 int64_t num_fused_shared_experts,
6161 double routed_scaling_factor,
62+ bool apply_routed_scaling_factor_on_output,
6263 Params params) {
6364 int tidx = threadIdx .x ;
6465 int64_t thread_row =
@@ -248,6 +249,9 @@ __device__ void moe_fused_gate_impl(
248249 for (int ii = 0 ; ii < topk; ++ii) {
249250 int64_t const idx = topk * thread_row + ii;
250251 output_ptr[idx] = output_ptr[idx] / output_sum;
252+ if (apply_routed_scaling_factor_on_output) {
253+ output_ptr[idx] *= routed_scaling_factor;
254+ }
251255 }
252256 }
253257}
@@ -282,7 +286,8 @@ __global__ void moe_fused_gate_kernel(
282286 int64_t topk_group,
283287 int64_t topk,
284288 int64_t num_fused_shared_experts,
285- double routed_scaling_factor) {
289+ double routed_scaling_factor,
290+ bool apply_routed_scaling_factor_on_output) {
286291 KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
287292 moe_fused_gate_impl<T>(
288293 input,
@@ -294,6 +299,7 @@ __global__ void moe_fused_gate_kernel(
294299 topk,
295300 num_fused_shared_experts,
296301 routed_scaling_factor,
302+ apply_routed_scaling_factor_on_output,
297303 params);
298304}
299305
@@ -314,7 +320,8 @@ __global__ void moe_fused_gate_kernel(
314320 topk_group, \
315321 topk, \
316322 num_fused_shared_experts, \
317- routed_scaling_factor); \
323+ routed_scaling_factor, \
324+ apply_routed_scaling_factor_on_output); \
318325 dispatched = true ; \
319326 } while (0 )
320327
@@ -342,7 +349,8 @@ __global__ void moe_fused_gate_kernel_dynamic(
342349 int64_t topk_group,
343350 int64_t topk,
344351 int64_t num_fused_shared_experts,
345- double routed_scaling_factor) {
352+ double routed_scaling_factor,
353+ bool apply_routed_scaling_factor_on_output) {
346354 KernelParamsDynamic params;
347355 params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
348356 params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -361,6 +369,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
361369 topk,
362370 num_fused_shared_experts,
363371 routed_scaling_factor,
372+ apply_routed_scaling_factor_on_output,
364373 params);
365374}
366375
@@ -374,7 +383,8 @@ std::vector<at::Tensor> moe_fused_gate(
374383 int64_t topk_group,
375384 int64_t topk,
376385 int64_t num_fused_shared_experts,
377- double routed_scaling_factor) {
386+ double routed_scaling_factor,
387+ bool apply_routed_scaling_factor_on_output) {
378388 int64_t num_rows = input.size (0 );
379389 int32_t num_experts = input.size (1 );
380390 auto options = torch::TensorOptions ().dtype (torch::kFloat32 ).device (torch::kCUDA );
@@ -473,7 +483,8 @@ std::vector<at::Tensor> moe_fused_gate(
473483 topk_group,
474484 topk,
475485 num_fused_shared_experts,
476- routed_scaling_factor);
486+ routed_scaling_factor,
487+ apply_routed_scaling_factor_on_output);
477488 } else if (input.scalar_type () == at::kHalf ) {
478489 moe_fused_gate_kernel_dynamic<float16_t ><<<num_blocks, block_dim, 0 , stream>>> (
479490 input.data_ptr (),
@@ -486,7 +497,8 @@ std::vector<at::Tensor> moe_fused_gate(
486497 topk_group,
487498 topk,
488499 num_fused_shared_experts,
489- routed_scaling_factor);
500+ routed_scaling_factor,
501+ apply_routed_scaling_factor_on_output);
490502 } else if (input.scalar_type () == at::kFloat ) {
491503 moe_fused_gate_kernel_dynamic<float32_t ><<<num_blocks, block_dim, 0 , stream>>> (
492504 input.data_ptr (),
@@ -499,7 +511,8 @@ std::vector<at::Tensor> moe_fused_gate(
499511 topk_group,
500512 topk,
501513 num_fused_shared_experts,
502- routed_scaling_factor);
514+ routed_scaling_factor,
515+ apply_routed_scaling_factor_on_output);
503516 } else {
504517 TORCH_CHECK (false , " Unsupported data type for moe_fused_gate" );
505518 }