diff --git a/python/sglang/jit_kernel/csrc/moe/grouped_topk.cuh b/python/sglang/jit_kernel/csrc/moe/grouped_topk.cuh
index 19677e0e7ed6..183a31d485a5 100644
--- a/python/sglang/jit_kernel/csrc/moe/grouped_topk.cuh
+++ b/python/sglang/jit_kernel/csrc/moe/grouped_topk.cuh
@@ -157,20 +157,14 @@ __global__ void grouped_topk_single_group_kernel(
     __syncwarp();
   }

-  // Phase 3: renormalize and write output
+  // Phase 3: renormalize and write output. `__shfl_xor_sync` requires every
+  // lane named in the mask (0xffffffff) to execute it, so we pad inactive
+  // lanes with 0 and run warp_sum_f32 on all 32 lanes uniformly.
+  float weight = (lane_id < topk) ? selected_weights[lane_id] : 0.0f;
+  float divisor = renormalize ? warp_sum_f32(weight) + 1e-20f : 1.0f;
   if (lane_id < topk) {
-    float weight = selected_weights[lane_id];
-    float final_weight = weight * scaling_factor;
-
-    if (renormalize) {
-      // Warp-level sum of selected weights (only lanes < topk contribute)
-      float partial = (lane_id < topk) ? weight : 0.0f;
-      float total = warp_sum_f32(partial);
-      final_weight = weight * scaling_factor / (total + 1e-20f);
-    }
-
     out_ids[lane_id] = selected_ids[lane_id];
-    out_vals[lane_id] = final_weight;
+    out_vals[lane_id] = weight * scaling_factor / divisor;
   }
 }
diff --git a/test/registered/models/test_nvidia_nemotron_3_nano.py b/test/registered/models/test_nvidia_nemotron_3_nano.py
index 4834265a2b33..75933f4a4ff8 100644
--- a/test/registered/models/test_nvidia_nemotron_3_nano.py
+++ b/test/registered/models/test_nvidia_nemotron_3_nano.py
@@ -4,11 +4,7 @@
 from sglang.test.kits.lm_eval_kit import LMEvalMixin
 from sglang.test.server_fixtures.default_fixture import DefaultServerBase

-register_cuda_ci(
-    est_time=564,
-    suite="stage-b-test-2-gpu-large",
-    disabled="Temporarily disabled; failing on main.",
-)
+register_cuda_ci(est_time=564, suite="stage-b-test-2-gpu-large")

 NEMOTRON_3_NANO_THINKING_ARGS = [
     "--trust-remote-code",
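
Note on the grouped_topk.cuh hunk: `warp_sum_f32` itself is not shown in this diff. A minimal sketch of the XOR-butterfly reduction it presumably wraps (an assumption for illustration, not a quote of the repository's code) shows why every lane named in the 0xffffffff mask must reach the shuffle, and hence why the caller now pads inactive lanes with 0.0f instead of guarding the call behind `lane_id < topk`:

// Hypothetical sketch of warp_sum_f32, assuming the standard XOR-butterfly
// reduction; the real definition lives elsewhere in grouped_topk.cuh.
__device__ __forceinline__ float warp_sum_f32(float v) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    // Each step exchanges partial sums with the lane whose id differs in
    // one bit. Behavior is undefined if any lane named in the mask skips
    // the call, so the caller must not branch around it. After five steps
    // every lane holds the sum over all 32 lanes.
    v += __shfl_xor_sync(0xffffffffu, v, offset);
  }
  return v;
}

Under this reading, lanes with `lane_id >= topk` contribute 0.0f to the reduction, so `divisor` equals the sum over the topk selected weights plus 1e-20f, reproducing the removed `total + 1e-20f` computation without divergent lanes around the shuffle.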