From 42b269da2a80f7c060403e3722f725da7494e9a9 Mon Sep 17 00:00:00 2001 From: yuhyao <827623970@qq.com> Date: Thu, 4 Sep 2025 18:10:37 +0800 Subject: [PATCH 1/3] Optimize by reducing unnecessary kernels. --- python/sglang/srt/layers/moe/cutlass_w4a8_moe.py | 11 ++++++----- python/sglang/srt/layers/moe/ep_moe/kernels.py | 16 ++++------------ python/sglang/srt/layers/quantization/w4afp8.py | 8 ++++---- 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py index 8073fd456011..242af809d91b 100644 --- a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @@ -10,14 +10,15 @@ silu_and_mul, ) +from sglang.srt.distributed import get_moe_expert_parallel_world_size from sglang.srt.layers.moe.ep_moe.kernels import ( + cutlass_w4_run_moe_ep_preproess, deepep_ll_get_cutlass_w4a8_moe_mm_data, deepep_permute_triton_kernel, deepep_post_reorder_triton_kernel, deepep_run_moe_deep_preprocess, post_reorder_triton_kernel_for_cutlass_moe, pre_reorder_triton_kernel_for_cutlass_moe, - run_moe_ep_preproess, silu_and_mul_masked_post_per_tensor_quant_fwd, ) @@ -108,11 +109,11 @@ def cutlass_w4a8_moe( assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1" device = a.device - topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids) + if get_moe_expert_parallel_world_size() > 1: + topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids) - _, src2dst, _ = run_moe_ep_preproess( + src2dst = cutlass_w4_run_moe_ep_preproess( topk_ids, - num_local_experts, ) gateup_input = torch.empty( @@ -151,7 +152,7 @@ def cutlass_w4a8_moe( ) c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16) - c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16) + c2 = torch.empty((m * topk, k), device=device, dtype=torch.bfloat16) cutlass_w4a8_moe_mm( c1, diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 166f42ea9e64..dd36432fcf3a 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -142,25 +142,17 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1) -def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int): - reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) - - seg_indptr = torch.zeros( - num_local_experts + 1, device=topk_ids.device, dtype=torch.int64 - ) - src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) - - compute_seg_indptr_triton_kernel[(num_local_experts,)]( - reorder_topk_ids, seg_indptr, topk_ids.numel() - ) +def cutlass_w4_run_moe_ep_preproess(topk_ids: torch.Tensor): + _, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) BLOCK_SIZE = 512 grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),) + src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) compute_src2dst_triton_kernel[grid]( reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE ) - return reorder_topk_ids, src2dst, seg_indptr + return src2dst @triton.jit diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index a9aafc97cf58..80fc42f01f2d 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -270,17 +270,17 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False) # Process input scales - w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item() + w13_input_scale_max = layer.w13_input_scale.max().to(torch.float32).item() new_w13_input_scale = torch.tensor( [w13_input_scale_max], - dtype=dtype, + dtype=torch.float32, device=device, ) layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False) - w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item() + w2_input_scale_max = layer.w2_input_scale.max().to(torch.float32).item() new_w2_input_scale = torch.tensor( - [w2_input_scale_max], dtype=dtype, device=device + [w2_input_scale_max], dtype=torch.float32, device=device ) layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False) From 48ea51d487c0599a122bde4bceb1e36e58656591 Mon Sep 17 00:00:00 2001 From: yuhyao <827623970@qq.com> Date: Fri, 14 Nov 2025 19:32:42 +0800 Subject: [PATCH 2/3] Optimize pre & post-reorder kernels and fuse some kernels. --- .../sglang/srt/layers/moe/cutlass_w4a8_moe.py | 27 +- .../sglang/srt/layers/moe/ep_moe/kernels.py | 259 +++++++++++++++--- .../sglang/srt/layers/quantization/w4afp8.py | 3 +- 3 files changed, 232 insertions(+), 57 deletions(-) diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py index 242af809d91b..2d35c0216fe4 100644 --- a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @@ -17,9 +17,10 @@ deepep_permute_triton_kernel, deepep_post_reorder_triton_kernel, deepep_run_moe_deep_preprocess, - post_reorder_triton_kernel_for_cutlass_moe, - pre_reorder_triton_kernel_for_cutlass_moe, + post_reorder_for_cutlass_moe, + pre_reorder_for_cutlass_moe, silu_and_mul_masked_post_per_tensor_quant_fwd, + silu_mul_static_tensorwise_quant_for_cutlass_moe, ) @@ -45,6 +46,7 @@ def cutlass_w4a8_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, + routed_scaling_factor: float = 1.0, ) -> torch.Tensor: """ This function computes a w4a8-quantized Mixture of Experts (MoE) layer @@ -122,7 +124,7 @@ def cutlass_w4a8_moe( dtype=torch.float8_e4m3fn, ) - pre_reorder_triton_kernel_for_cutlass_moe[(m,)]( + pre_reorder_for_cutlass_moe( a, gateup_input, src2dst, @@ -130,8 +132,8 @@ def cutlass_w4a8_moe( a1_scale, num_local_experts, topk, + m, k, - BLOCK_SIZE=512, ) # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel, @@ -170,13 +172,12 @@ def cutlass_w4a8_moe( topk, ) - intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16) - silu_and_mul(c1, intermediate) - intermediate_q = torch.empty( - intermediate.shape, dtype=torch.float8_e4m3fn, device=device + (m * topk, n), dtype=torch.float8_e4m3fn, device=device + ) + silu_mul_static_tensorwise_quant_for_cutlass_moe( + c1, intermediate_q, a2_scale.float(), m * topk, n ) - sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True) cutlass_w4a8_moe_mm( c2, @@ -195,16 +196,18 @@ def cutlass_w4a8_moe( ) output = torch.empty_like(a) - post_reorder_triton_kernel_for_cutlass_moe[(m,)]( + + post_reorder_for_cutlass_moe( c2, output, src2dst, topk_ids, topk_weights, - topk, num_local_experts, + topk, + m, k, - BLOCK_SIZE=512, + routed_scaling_factor, ) return output diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index dd36432fcf3a..05a9e68aed9b 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -16,6 +16,60 @@ import triton.language as tl +def _get_launch_config_1d(device, numel): + MAX_THREADS_PER_BLOCK = 1024 + MIN_THREADS_PER_BLOCK = 512 + MAX_WAVES = 8 # empirical numbers + + props = torch.cuda.get_device_properties(device) + sm_count = props.multi_processor_count + max_threads_per_sm = props.max_threads_per_multi_processor + max_num_blocks = sm_count * max_threads_per_sm // MAX_THREADS_PER_BLOCK + + block_dim = MAX_THREADS_PER_BLOCK + + def get_num_blocks(block_dim): + return triton.cdiv(numel, block_dim) + + while ( + block_dim > MIN_THREADS_PER_BLOCK + and get_num_blocks(block_dim // 2) <= max_num_blocks + ): + block_dim = block_dim // 2 + + num_blocks = get_num_blocks(block_dim) + grid_dim = min(num_blocks, max_num_blocks * MAX_WAVES) + + return (grid_dim,), block_dim + + +def _get_launch_config_2d(device, m, n): + MAX_THREADS_PER_BLOCK = 1024 + MIN_THREADS_PER_BLOCK = 512 + MAX_WAVES = 8 # empirical numbers + + props = torch.cuda.get_device_properties(device) + sm_count = props.multi_processor_count + max_threads_per_sm = props.max_threads_per_multi_processor + max_num_blocks = sm_count * max_threads_per_sm // MAX_THREADS_PER_BLOCK + + block_dim = MAX_THREADS_PER_BLOCK + + def get_num_blocks(block_dim): + return m * triton.cdiv(n, block_dim) + + while ( + block_dim > MIN_THREADS_PER_BLOCK + and get_num_blocks(block_dim // 2) <= max_num_blocks + ): + block_dim = block_dim // 2 + + grid_dim_x = triton.cdiv(n, block_dim) + grid_dim_y = max(min(m, max_num_blocks * MAX_WAVES // grid_dim_x), 1) + + return (grid_dim_y, grid_dim_x), block_dim + + @triton.jit def deepep_permute_triton_kernel( input_ptr, @@ -164,36 +218,65 @@ def pre_reorder_triton_kernel_for_cutlass_moe( a1_scales_ptr, num_local_experts, topk, + num_tokens, hidden_size, BLOCK_SIZE: tl.constexpr, + NUM_STAGES: tl.constexpr, ): OutDtype = gateup_input_ptr.dtype.element_ty - src_idx_int32 = tl.program_id(0) - src_idx = src_idx_int32.to(tl.int64) - src2dst_ptr = src2dst_ptr + src_idx * topk - topk_ids_ptr = topk_ids_ptr + src_idx * topk - src_ptr = input_ptr + src_idx * hidden_size + if a1_scales_ptr is not None: + a1_scale = 1.0 / tl.load(a1_scales_ptr) + else: + a1_scale = 1.0 - vec = tl.arange(0, BLOCK_SIZE) + offset = BLOCK_SIZE * tl.program_id(1) + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size - for idx in range(topk): - expert_id = tl.load(topk_ids_ptr + idx) - if expert_id != num_local_experts: - if a1_scales_ptr is not None: - scale = 1.0 / tl.load(a1_scales_ptr) - else: - scale = 1.0 + for src_idx_int32 in tl.range( + tl.program_id(0), num_tokens, tl.num_programs(0), num_stages=NUM_STAGES + ): + src_idx = src_idx_int32.to(tl.int64) + token_src2dst_ptr = src2dst_ptr + src_idx * topk + token_topk_ids_ptr = topk_ids_ptr + src_idx * topk + + src_ptr_offs = input_ptr + src_idx * hidden_size + offset + dst_ptr_offs = gateup_input_ptr + offset + in_data = tl.load(src_ptr_offs, mask=mask).to(tl.float32) + out_data = (in_data * a1_scale).to(OutDtype) + for idx in range(topk): + expert_id = tl.load(token_topk_ids_ptr + idx) + if expert_id != num_local_experts: + dst_idx = tl.load(token_src2dst_ptr + idx) + tl.store(dst_ptr_offs + dst_idx * hidden_size, out_data, mask=mask) - dst_idx_int32 = tl.load(src2dst_ptr + idx) - dst_idx = dst_idx_int32.to(tl.int64) - dst_ptr = gateup_input_ptr + dst_idx * hidden_size - for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): - offset = start_offset + vec - mask = offset < hidden_size - in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) - out_data = (in_data * scale).to(OutDtype) - tl.store(dst_ptr + offset, out_data, mask=mask) + +def pre_reorder_for_cutlass_moe( + input, + gateup_input, + src2dst, + topk_ids, + a1_scales, + num_local_experts, + topk, + num_tokens, + hidden_size, +): + grid, block_dim = _get_launch_config_2d(input.device, num_tokens, hidden_size) + + pre_reorder_triton_kernel_for_cutlass_moe[grid]( + input_ptr=input, + gateup_input_ptr=gateup_input, + src2dst_ptr=src2dst, + topk_ids_ptr=topk_ids, + a1_scales_ptr=a1_scales, + num_local_experts=num_local_experts, + topk=topk, + num_tokens=num_tokens, + hidden_size=hidden_size, + BLOCK_SIZE=block_dim, + NUM_STAGES=3, + ) # copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py @@ -343,6 +426,60 @@ def silu_and_mul_masked_post_quant_fwd( return +@triton.jit +def silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe( + input_ptr, + output_ptr, + scale_ptr, + num_tokens, + intermediate_size, + BLOCK_SIZE: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + OutDtype = output_ptr.dtype.element_ty + + numel = num_tokens * intermediate_size + gate_ptr = input_ptr + up_ptr = input_ptr + intermediate_size + scale = 1.0 / tl.load(scale_ptr) + + start_idx = tl.program_id(0) * BLOCK_SIZE + step = tl.num_programs(0) * BLOCK_SIZE + + for id in tl.range(start_idx, numel, step, num_stages=NUM_STAGES): + ids = id + tl.arange(0, BLOCK_SIZE) + token_ids = ids // intermediate_size + mask = ids < numel + + offs = ids + token_ids * intermediate_size + gate = tl.load(gate_ptr + offs, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + offs, mask=mask, other=0.0).to(tl.float32) + output = gate / (1 + tl.exp(-gate)) * up * scale + tl.store(output_ptr + ids, output.to(OutDtype), mask=mask) + + +def silu_mul_static_tensorwise_quant_for_cutlass_moe( + input: torch.Tensor, + output: torch.Tensor, + scale: torch.Tensor, + num_tokens: int, + intermediate_size: int, +): + grid, block_dim = _get_launch_config_1d( + input.device, num_tokens * intermediate_size + ) + + silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe[grid]( + input_ptr=input, + output_ptr=output, + scale_ptr=scale, + num_tokens=num_tokens, + intermediate_size=intermediate_size, + BLOCK_SIZE=block_dim, + NUM_STAGES=3, + ) + + @triton.jit def post_reorder_triton_kernel_for_cutlass_moe( down_output_ptr, @@ -350,38 +487,74 @@ def post_reorder_triton_kernel_for_cutlass_moe( src2dst_ptr, topk_ids_ptr, topk_weights_ptr, - topk, num_local_experts, + topk, + num_tokens, hidden_size, + routed_scaling_factor: float, BLOCK_SIZE: tl.constexpr, + NUM_STAGES: tl.constexpr, ): - InDtype = down_output_ptr.dtype.element_ty + OutDtype = output_ptr.dtype.element_ty - src_idx_int32 = tl.program_id(0) - src_idx = src_idx_int32.to(tl.int64) - src2dst_ptr = src2dst_ptr + src_idx * topk - topk_ids_ptr = topk_ids_ptr + src_idx * topk - topk_weights_ptr = topk_weights_ptr + src_idx * topk + offset = BLOCK_SIZE * tl.program_id(1) + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size - store_ptr = output_ptr + src_idx * hidden_size - - vec = tl.arange(0, BLOCK_SIZE) + down_output_ptr_offs = down_output_ptr + offset + output_ptr_offs = output_ptr + offset - for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): - offset = start_offset + vec - mask = offset < hidden_size + for src_idx_int32 in tl.range( + tl.program_id(0), num_tokens, tl.num_programs(0), num_stages=NUM_STAGES + ): + src_idx = src_idx_int32.to(tl.int64) + token_src2dst_ptr = src2dst_ptr + src_idx * topk + token_topk_ids_ptr = topk_ids_ptr + src_idx * topk + token_topk_weights_ptr = topk_weights_ptr + src_idx * topk - sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + sum_vec = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for idx in range(topk): - expert_id = tl.load(topk_ids_ptr + idx) + expert_id = tl.load(token_topk_ids_ptr + idx) if expert_id != num_local_experts: - dst_idx_int32 = tl.load(src2dst_ptr + idx) + dst_idx_int32 = tl.load(token_src2dst_ptr + idx) dst_idx = dst_idx_int32.to(tl.int64) - weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) - load_ptr = down_output_ptr + dst_idx * hidden_size - in_data = tl.load(load_ptr + offset, mask=mask) - sum_vec += in_data * weigh_scale - tl.store(store_ptr + offset, sum_vec, mask=mask) + dst_idx = dst_idx + weight_scale = tl.load(token_topk_weights_ptr + idx).to(tl.float32) + load_ptr_offs = down_output_ptr_offs + dst_idx * hidden_size + in_data = tl.load(load_ptr_offs, mask=mask).to(tl.float32) + sum_vec += in_data * weight_scale + sum_vec *= routed_scaling_factor + store_ptr_offs = output_ptr_offs + src_idx * hidden_size + tl.store(store_ptr_offs, sum_vec.to(OutDtype), mask=mask) + + +def post_reorder_for_cutlass_moe( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + num_local_experts, + topk, + num_tokens, + hidden_size, + routed_scaling_factor: float, +): + grid, block_dim = _get_launch_config_2d(down_output.device, num_tokens, hidden_size) + + post_reorder_triton_kernel_for_cutlass_moe[grid]( + down_output_ptr=down_output, + output_ptr=output, + src2dst_ptr=src2dst, + topk_ids_ptr=topk_ids, + topk_weights_ptr=topk_weights, + num_local_experts=num_local_experts, + topk=topk, + num_tokens=num_tokens, + hidden_size=hidden_size, + routed_scaling_factor=routed_scaling_factor, + BLOCK_SIZE=block_dim, + NUM_STAGES=3, + ) @triton.jit diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index 80fc42f01f2d..6ab8f4a6918f 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -324,9 +324,8 @@ def apply( self.problem_sizes2, layer.w13_input_scale, layer.w2_input_scale, + routed_scaling_factor=self.moe_runner_config.routed_scaling_factor or 1.0, ) - if self.moe_runner_config.routed_scaling_factor is not None: - output *= self.moe_runner_config.routed_scaling_factor return StandardCombineInput(hidden_states=output) def apply_deepep_ll( From be30f602bd4767d0ad12d01350ddb00f53b8007b Mon Sep 17 00:00:00 2001 From: yuhyao <827623970@qq.com> Date: Fri, 14 Nov 2025 19:47:08 +0800 Subject: [PATCH 3/3] Optimize silu_mul for ep. --- .../sglang/srt/layers/moe/cutlass_w4a8_moe.py | 2 +- .../sglang/srt/layers/moe/ep_moe/kernels.py | 20 +++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py index 2d35c0216fe4..0e17d5cc709d 100644 --- a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @@ -176,7 +176,7 @@ def cutlass_w4a8_moe( (m * topk, n), dtype=torch.float8_e4m3fn, device=device ) silu_mul_static_tensorwise_quant_for_cutlass_moe( - c1, intermediate_q, a2_scale.float(), m * topk, n + c1, intermediate_q, a2_scale.float(), expert_offsets[-1:], m * topk, n ) cutlass_w4a8_moe_mm( diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 05a9e68aed9b..044c590f2200 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -233,8 +233,11 @@ def pre_reorder_triton_kernel_for_cutlass_moe( offset = BLOCK_SIZE * tl.program_id(1) + tl.arange(0, BLOCK_SIZE) mask = offset < hidden_size + start_src_idx = tl.program_id(0) + step = tl.num_programs(0) + for src_idx_int32 in tl.range( - tl.program_id(0), num_tokens, tl.num_programs(0), num_stages=NUM_STAGES + start_src_idx, num_tokens, step, num_stages=NUM_STAGES ): src_idx = src_idx_int32.to(tl.int64) token_src2dst_ptr = src2dst_ptr + src_idx * topk @@ -431,13 +434,14 @@ def silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe( input_ptr, output_ptr, scale_ptr, - num_tokens, + num_tokens_tensor_ptr, intermediate_size, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr, ): OutDtype = output_ptr.dtype.element_ty + num_tokens = tl.load(num_tokens_tensor_ptr) numel = num_tokens * intermediate_size gate_ptr = input_ptr up_ptr = input_ptr + intermediate_size @@ -462,18 +466,19 @@ def silu_mul_static_tensorwise_quant_for_cutlass_moe( input: torch.Tensor, output: torch.Tensor, scale: torch.Tensor, - num_tokens: int, + num_tokens_tensor: torch.Tensor, + expected_num_tokens: int, intermediate_size: int, ): grid, block_dim = _get_launch_config_1d( - input.device, num_tokens * intermediate_size + input.device, expected_num_tokens * intermediate_size ) silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe[grid]( input_ptr=input, output_ptr=output, scale_ptr=scale, - num_tokens=num_tokens, + num_tokens_tensor_ptr=num_tokens_tensor, intermediate_size=intermediate_size, BLOCK_SIZE=block_dim, NUM_STAGES=3, @@ -503,8 +508,11 @@ def post_reorder_triton_kernel_for_cutlass_moe( down_output_ptr_offs = down_output_ptr + offset output_ptr_offs = output_ptr + offset + start_src_idx = tl.program_id(0) + step = tl.num_programs(0) + for src_idx_int32 in tl.range( - tl.program_id(0), num_tokens, tl.num_programs(0), num_stages=NUM_STAGES + start_src_idx, num_tokens, step, num_stages=NUM_STAGES ): src_idx = src_idx_int32.to(tl.int64) token_src2dst_ptr = src2dst_ptr + src_idx * topk