Review notes (commentary placed before the first "diff --git" header; it is
not part of any hunk).

Line breaks and indentation were restored from a whitespace-collapsed copy of
this patch. Indentation of context lines is inferred from the Python
structure -- verify it against the target tree before applying.

What the patch does:
  * cutlass_w4a8_moe.py
      - cutlass_w4a8_moe() gains `routed_scaling_factor` (default 1.0) and
        forwards it to post_reorder_for_cutlass_moe().
      - The `-1 -> num_local_experts` remap of topk_ids now only runs when
        get_moe_expert_parallel_world_size() > 1.
      - silu_and_mul + sgl_per_tensor_quant_fp8 are replaced by one call to
        silu_mul_static_tensorwise_quant_for_cutlass_moe(), which reads the
        token count from expert_offsets[-1:].
      - c2 is allocated with torch.empty instead of torch.zeros.
  * ep_moe/kernels.py
      - _get_launch_config_1d/_2d pick a block size between 512 and 1024 and
        cap the grid at max_num_blocks * MAX_WAVES.
      - The pre/post reorder kernels take a 2-D grid: axis 0 strides over
        tokens, axis 1 selects a BLOCK_SIZE slice of hidden_size.
      - post_reorder accumulates in float32, multiplies by
        routed_scaling_factor, then casts to the output dtype.
      - Python wrappers (pre_reorder_for_cutlass_moe, etc.) compute the launch
        configuration and call the kernels with keyword arguments.
  * quantization/w4afp8.py
      - w13/w2 input scales are stored as float32; the routed scaling factor
        is passed into cutlass_w4a8_moe() instead of being applied afterwards.

Changes made during review (every hunk keeps its original line counts):
  1. pre_reorder_triton_kernel_for_cutlass_moe: the src2dst index is widened
     to int64 before it is multiplied by hidden_size, as the removed code did
     and as the post_reorder kernel still does. Without this the row offset
     is computed in 32-bit arithmetic.
  2. post_reorder_triton_kernel_for_cutlass_moe: the no-op statement
     `dst_idx = dst_idx` is replaced by a comment.
  3. silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe: the loop
     variable `id` (shadows a builtin) is renamed to `block_start`, and `ids`
     is widened to int64 because `offs` spans twice the output element count.

diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
index 8073fd456011..0e17d5cc709d 100644
--- a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
+++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -10,15 +10,17 @@
     silu_and_mul,
 )
 
+from sglang.srt.distributed import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    cutlass_w4_run_moe_ep_preproess,
     deepep_ll_get_cutlass_w4a8_moe_mm_data,
     deepep_permute_triton_kernel,
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
-    post_reorder_triton_kernel_for_cutlass_moe,
-    pre_reorder_triton_kernel_for_cutlass_moe,
-    run_moe_ep_preproess,
+    post_reorder_for_cutlass_moe,
+    pre_reorder_for_cutlass_moe,
     silu_and_mul_masked_post_per_tensor_quant_fwd,
+    silu_mul_static_tensorwise_quant_for_cutlass_moe,
 )
 
 
@@ -44,6 +46,7 @@ def cutlass_w4a8_moe(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     apply_router_weight_on_input: bool = False,
+    routed_scaling_factor: float = 1.0,
 ) -> torch.Tensor:
     """
     This function computes a w4a8-quantized Mixture of Experts (MoE) layer
@@ -108,11 +111,11 @@ def cutlass_w4a8_moe(
         assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
 
     device = a.device
-    topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
+    if get_moe_expert_parallel_world_size() > 1:
+        topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
 
-    _, src2dst, _ = run_moe_ep_preproess(
+    src2dst = cutlass_w4_run_moe_ep_preproess(
         topk_ids,
-        num_local_experts,
     )
 
     gateup_input = torch.empty(
@@ -121,7 +124,7 @@ def cutlass_w4a8_moe(
         dtype=torch.float8_e4m3fn,
     )
 
-    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
+    pre_reorder_for_cutlass_moe(
         a,
         gateup_input,
         src2dst,
@@ -129,8 +132,8 @@ def cutlass_w4a8_moe(
         a1_scale,
         num_local_experts,
         topk,
+        m,
         k,
-        BLOCK_SIZE=512,
     )
 
     # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
@@ -151,7 +154,7 @@ def cutlass_w4a8_moe(
     )
 
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
-    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
+    c2 = torch.empty((m * topk, k), device=device, dtype=torch.bfloat16)
 
     cutlass_w4a8_moe_mm(
         c1,
@@ -169,13 +172,12 @@ def cutlass_w4a8_moe(
         topk,
     )
 
-    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
-    silu_and_mul(c1, intermediate)
-
     intermediate_q = torch.empty(
-        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
+        (m * topk, n), dtype=torch.float8_e4m3fn, device=device
+    )
+    silu_mul_static_tensorwise_quant_for_cutlass_moe(
+        c1, intermediate_q, a2_scale.float(), expert_offsets[-1:], m * topk, n
     )
-    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
 
     cutlass_w4a8_moe_mm(
         c2,
@@ -194,16 +196,18 @@ def cutlass_w4a8_moe(
     )
 
     output = torch.empty_like(a)
-    post_reorder_triton_kernel_for_cutlass_moe[(m,)](
+
+    post_reorder_for_cutlass_moe(
         c2,
         output,
         src2dst,
         topk_ids,
         topk_weights,
-        topk,
         num_local_experts,
+        topk,
+        m,
         k,
-        BLOCK_SIZE=512,
+        routed_scaling_factor,
     )
     return output
 
diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index 166f42ea9e64..044c590f2200 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -16,6 +16,60 @@
 import triton.language as tl
 
 
+def _get_launch_config_1d(device, numel):
+    MAX_THREADS_PER_BLOCK = 1024
+    MIN_THREADS_PER_BLOCK = 512
+    MAX_WAVES = 8  # empirical numbers
+
+    props = torch.cuda.get_device_properties(device)
+    sm_count = props.multi_processor_count
+    max_threads_per_sm = props.max_threads_per_multi_processor
+    max_num_blocks = sm_count * max_threads_per_sm // MAX_THREADS_PER_BLOCK
+
+    block_dim = MAX_THREADS_PER_BLOCK
+
+    def get_num_blocks(block_dim):
+        return triton.cdiv(numel, block_dim)
+
+    while (
+        block_dim > MIN_THREADS_PER_BLOCK
+        and get_num_blocks(block_dim // 2) <= max_num_blocks
+    ):
+        block_dim = block_dim // 2
+
+    num_blocks = get_num_blocks(block_dim)
+    grid_dim = min(num_blocks, max_num_blocks * MAX_WAVES)
+
+    return (grid_dim,), block_dim
+
+
+def _get_launch_config_2d(device, m, n):
+    MAX_THREADS_PER_BLOCK = 1024
+    MIN_THREADS_PER_BLOCK = 512
+    MAX_WAVES = 8  # empirical numbers
+
+    props = torch.cuda.get_device_properties(device)
+    sm_count = props.multi_processor_count
+    max_threads_per_sm = props.max_threads_per_multi_processor
+    max_num_blocks = sm_count * max_threads_per_sm // MAX_THREADS_PER_BLOCK
+
+    block_dim = MAX_THREADS_PER_BLOCK
+
+    def get_num_blocks(block_dim):
+        return m * triton.cdiv(n, block_dim)
+
+    while (
+        block_dim > MIN_THREADS_PER_BLOCK
+        and get_num_blocks(block_dim // 2) <= max_num_blocks
+    ):
+        block_dim = block_dim // 2
+
+    grid_dim_x = triton.cdiv(n, block_dim)
+    grid_dim_y = max(min(m, max_num_blocks * MAX_WAVES // grid_dim_x), 1)
+
+    return (grid_dim_y, grid_dim_x), block_dim
+
+
 @triton.jit
 def deepep_permute_triton_kernel(
     input_ptr,
@@ -142,25 +196,17 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)
 
 
-def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
-    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
-
-    seg_indptr = torch.zeros(
-        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
-    )
-    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
-
-    compute_seg_indptr_triton_kernel[(num_local_experts,)](
-        reorder_topk_ids, seg_indptr, topk_ids.numel()
-    )
+def cutlass_w4_run_moe_ep_preproess(topk_ids: torch.Tensor):
+    _, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
 
     BLOCK_SIZE = 512
     grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
     compute_src2dst_triton_kernel[grid](
         reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
     )
 
-    return reorder_topk_ids, src2dst, seg_indptr
+    return src2dst
 
 
 @triton.jit
@@ -172,36 +218,68 @@ def pre_reorder_triton_kernel_for_cutlass_moe(
     a1_scales_ptr,
     num_local_experts,
     topk,
+    num_tokens,
     hidden_size,
     BLOCK_SIZE: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
 
-    src_idx_int32 = tl.program_id(0)
-    src_idx = src_idx_int32.to(tl.int64)
-    src2dst_ptr = src2dst_ptr + src_idx * topk
-    topk_ids_ptr = topk_ids_ptr + src_idx * topk
-    src_ptr = input_ptr + src_idx * hidden_size
+    if a1_scales_ptr is not None:
+        a1_scale = 1.0 / tl.load(a1_scales_ptr)
+    else:
+        a1_scale = 1.0
 
-    vec = tl.arange(0, BLOCK_SIZE)
+    offset = BLOCK_SIZE * tl.program_id(1) + tl.arange(0, BLOCK_SIZE)
+    mask = offset < hidden_size
 
-    for idx in range(topk):
-        expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id != num_local_experts:
-            if a1_scales_ptr is not None:
-                scale = 1.0 / tl.load(a1_scales_ptr)
-            else:
-                scale = 1.0
+    start_src_idx = tl.program_id(0)
+    step = tl.num_programs(0)
+
+    for src_idx_int32 in tl.range(
+        start_src_idx, num_tokens, step, num_stages=NUM_STAGES
+    ):
+        src_idx = src_idx_int32.to(tl.int64)
+        token_src2dst_ptr = src2dst_ptr + src_idx * topk
+        token_topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+        src_ptr_offs = input_ptr + src_idx * hidden_size + offset
+        dst_ptr_offs = gateup_input_ptr + offset
+        in_data = tl.load(src_ptr_offs, mask=mask).to(tl.float32)
+        out_data = (in_data * a1_scale).to(OutDtype)
+        for idx in range(topk):
+            expert_id = tl.load(token_topk_ids_ptr + idx)
+            if expert_id != num_local_experts:
+                dst_idx = tl.load(token_src2dst_ptr + idx).to(tl.int64)
+                tl.store(dst_ptr_offs + dst_idx * hidden_size, out_data, mask=mask)
 
-            dst_idx_int32 = tl.load(src2dst_ptr + idx)
-            dst_idx = dst_idx_int32.to(tl.int64)
-            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
-            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
-                offset = start_offset + vec
-                mask = offset < hidden_size
-                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
-                out_data = (in_data * scale).to(OutDtype)
-                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+def pre_reorder_for_cutlass_moe(
+    input,
+    gateup_input,
+    src2dst,
+    topk_ids,
+    a1_scales,
+    num_local_experts,
+    topk,
+    num_tokens,
+    hidden_size,
+):
+    grid, block_dim = _get_launch_config_2d(input.device, num_tokens, hidden_size)
+
+    pre_reorder_triton_kernel_for_cutlass_moe[grid](
+        input_ptr=input,
+        gateup_input_ptr=gateup_input,
+        src2dst_ptr=src2dst,
+        topk_ids_ptr=topk_ids,
+        a1_scales_ptr=a1_scales,
+        num_local_experts=num_local_experts,
+        topk=topk,
+        num_tokens=num_tokens,
+        hidden_size=hidden_size,
+        BLOCK_SIZE=block_dim,
+        NUM_STAGES=3,
+    )
 
 
 # copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@@ -351,6 +429,62 @@ def silu_and_mul_masked_post_quant_fwd(
     return
 
 
+@triton.jit
+def silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe(
+    input_ptr,
+    output_ptr,
+    scale_ptr,
+    num_tokens_tensor_ptr,
+    intermediate_size,
+    BLOCK_SIZE: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+):
+    OutDtype = output_ptr.dtype.element_ty
+
+    num_tokens = tl.load(num_tokens_tensor_ptr)
+    numel = num_tokens * intermediate_size
+    gate_ptr = input_ptr
+    up_ptr = input_ptr + intermediate_size
+    scale = 1.0 / tl.load(scale_ptr)
+
+    start_idx = tl.program_id(0) * BLOCK_SIZE
+    step = tl.num_programs(0) * BLOCK_SIZE
+
+    for block_start in tl.range(start_idx, numel, step, num_stages=NUM_STAGES):
+        ids = (block_start + tl.arange(0, BLOCK_SIZE)).to(tl.int64)
+        token_ids = ids // intermediate_size
+        mask = ids < numel
+
+        offs = ids + token_ids * intermediate_size
+        gate = tl.load(gate_ptr + offs, mask=mask, other=0.0).to(tl.float32)
+        up = tl.load(up_ptr + offs, mask=mask, other=0.0).to(tl.float32)
+        output = gate / (1 + tl.exp(-gate)) * up * scale
+        tl.store(output_ptr + ids, output.to(OutDtype), mask=mask)
+
+
+def silu_mul_static_tensorwise_quant_for_cutlass_moe(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    scale: torch.Tensor,
+    num_tokens_tensor: torch.Tensor,
+    expected_num_tokens: int,
+    intermediate_size: int,
+):
+    grid, block_dim = _get_launch_config_1d(
+        input.device, expected_num_tokens * intermediate_size
+    )
+
+    silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe[grid](
+        input_ptr=input,
+        output_ptr=output,
+        scale_ptr=scale,
+        num_tokens_tensor_ptr=num_tokens_tensor,
+        intermediate_size=intermediate_size,
+        BLOCK_SIZE=block_dim,
+        NUM_STAGES=3,
+    )
+
+
 @triton.jit
 def post_reorder_triton_kernel_for_cutlass_moe(
     down_output_ptr,
@@ -358,38 +492,77 @@ def post_reorder_triton_kernel_for_cutlass_moe(
     src2dst_ptr,
     topk_ids_ptr,
     topk_weights_ptr,
-    topk,
     num_local_experts,
+    topk,
+    num_tokens,
     hidden_size,
+    routed_scaling_factor: float,
     BLOCK_SIZE: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
 ):
-    InDtype = down_output_ptr.dtype.element_ty
+    OutDtype = output_ptr.dtype.element_ty
 
-    src_idx_int32 = tl.program_id(0)
-    src_idx = src_idx_int32.to(tl.int64)
-    src2dst_ptr = src2dst_ptr + src_idx * topk
-    topk_ids_ptr = topk_ids_ptr + src_idx * topk
-    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+    offset = BLOCK_SIZE * tl.program_id(1) + tl.arange(0, BLOCK_SIZE)
+    mask = offset < hidden_size
 
-    store_ptr = output_ptr + src_idx * hidden_size
+    down_output_ptr_offs = down_output_ptr + offset
+    output_ptr_offs = output_ptr + offset
 
-    vec = tl.arange(0, BLOCK_SIZE)
+    start_src_idx = tl.program_id(0)
+    step = tl.num_programs(0)
 
-    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
-        offset = start_offset + vec
-        mask = offset < hidden_size
+    for src_idx_int32 in tl.range(
+        start_src_idx, num_tokens, step, num_stages=NUM_STAGES
+    ):
+        src_idx = src_idx_int32.to(tl.int64)
+        token_src2dst_ptr = src2dst_ptr + src_idx * topk
+        token_topk_ids_ptr = topk_ids_ptr + src_idx * topk
+        token_topk_weights_ptr = topk_weights_ptr + src_idx * topk
 
-        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
         for idx in range(topk):
-            expert_id = tl.load(topk_ids_ptr + idx)
+            expert_id = tl.load(token_topk_ids_ptr + idx)
             if expert_id != num_local_experts:
-                dst_idx_int32 = tl.load(src2dst_ptr + idx)
+                dst_idx_int32 = tl.load(token_src2dst_ptr + idx)
                 dst_idx = dst_idx_int32.to(tl.int64)
-                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
-                load_ptr = down_output_ptr + dst_idx * hidden_size
-                in_data = tl.load(load_ptr + offset, mask=mask)
-                sum_vec += in_data * weigh_scale
-        tl.store(store_ptr + offset, sum_vec, mask=mask)
+                # dst_idx is int64 so dst_idx * hidden_size is not 32-bit math
+                weight_scale = tl.load(token_topk_weights_ptr + idx).to(tl.float32)
+                load_ptr_offs = down_output_ptr_offs + dst_idx * hidden_size
+                in_data = tl.load(load_ptr_offs, mask=mask).to(tl.float32)
+                sum_vec += in_data * weight_scale
+        sum_vec *= routed_scaling_factor
+        store_ptr_offs = output_ptr_offs + src_idx * hidden_size
+        tl.store(store_ptr_offs, sum_vec.to(OutDtype), mask=mask)
+
+
+def post_reorder_for_cutlass_moe(
+    down_output,
+    output,
+    src2dst,
+    topk_ids,
+    topk_weights,
+    num_local_experts,
+    topk,
+    num_tokens,
+    hidden_size,
+    routed_scaling_factor: float,
+):
+    grid, block_dim = _get_launch_config_2d(down_output.device, num_tokens, hidden_size)
+
+    post_reorder_triton_kernel_for_cutlass_moe[grid](
+        down_output_ptr=down_output,
+        output_ptr=output,
+        src2dst_ptr=src2dst,
+        topk_ids_ptr=topk_ids,
+        topk_weights_ptr=topk_weights,
+        num_local_experts=num_local_experts,
+        topk=topk,
+        num_tokens=num_tokens,
+        hidden_size=hidden_size,
+        routed_scaling_factor=routed_scaling_factor,
+        BLOCK_SIZE=block_dim,
+        NUM_STAGES=3,
+    )
 
 
 @triton.jit
diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py
index a9aafc97cf58..6ab8f4a6918f 100644
--- a/python/sglang/srt/layers/quantization/w4afp8.py
+++ b/python/sglang/srt/layers/quantization/w4afp8.py
@@ -270,17 +270,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
         layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)
 
         # Process input scales
-        w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item()
+        w13_input_scale_max = layer.w13_input_scale.max().to(torch.float32).item()
         new_w13_input_scale = torch.tensor(
             [w13_input_scale_max],
-            dtype=dtype,
+            dtype=torch.float32,
             device=device,
         )
         layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False)
 
-        w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item()
+        w2_input_scale_max = layer.w2_input_scale.max().to(torch.float32).item()
         new_w2_input_scale = torch.tensor(
-            [w2_input_scale_max], dtype=dtype, device=device
+            [w2_input_scale_max], dtype=torch.float32, device=device
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
 
@@ -324,9 +324,8 @@ def apply(
             self.problem_sizes2,
             layer.w13_input_scale,
             layer.w2_input_scale,
+            routed_scaling_factor=self.moe_runner_config.routed_scaling_factor or 1.0,
         )
-        if self.moe_runner_config.routed_scaling_factor is not None:
-            output *= self.moe_runner_config.routed_scaling_factor
         return StandardCombineInput(hidden_states=output)
 
     def apply_deepep_ll(