diff --git a/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py b/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
new file mode 100644
index 00000000000..13ff617448e
--- /dev/null
+++ b/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
@@ -0,0 +1,199 @@
+import torch
+import triton
+import triton.language as tl
+from triton.testing import do_bench
+
+
+# _moe_sum_reduce_kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
+@triton.jit
+def _moe_sum_reduce_kernel(
+    input_ptr,
+    input_stride_0,
+    input_stride_1,
+    input_stride_2,
+    output_ptr,
+    output_stride_0,
+    output_stride_1,
+    token_num: int,
+    topk_num: int,
+    hidden_dim: int,
+    routed_scaling_factor: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DIM: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
+    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
+    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
+
+    token_block_id = tl.program_id(0)
+    dim_block_id = tl.program_id(1)
+
+    token_start = token_block_id * BLOCK_M
+    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+
+    dim_start = dim_block_id * BLOCK_DIM
+    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+
+    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+
+    for token_index in range(token_start, token_end):
+        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
+        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
+        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+            tmp = tl.load(
+                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
+            )
+            accumulator += tmp
+        accumulator = accumulator * routed_scaling_factor
+        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
+        tl.store(
+            store_t_ptr,
+            accumulator.to(input_ptr.dtype.element_ty),
+            mask=offs_dim < dim_end,
+        )
+
+
+def moe_sum_reduce(
+    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
+):
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+
+    token_num, topk_num, hidden_dim = input.shape
+    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
+
+    BLOCK_M = 1
+    BLOCK_DIM = 2048
+    NUM_STAGE = 1
+    num_warps = 8
+
+    grid = (
+        triton.cdiv(token_num, BLOCK_M),
+        triton.cdiv(hidden_dim, BLOCK_DIM),
+    )
+
+    _moe_sum_reduce_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        token_num=token_num,
+        topk_num=topk_num,
+        hidden_dim=hidden_dim,
+        routed_scaling_factor=routed_scaling_factor,
+        BLOCK_M=BLOCK_M,
+        BLOCK_DIM=BLOCK_DIM,
+        NUM_STAGE=NUM_STAGE,
+        num_warps=num_warps,
+    )
+    return
+
+
+def compute_sum_scaled_baseline(
+    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
+) -> torch.Tensor:
+    torch.sum(x, dim=1, out=out)
+    out.mul_(routed_scaling_factor)
+    return out
+
+
+@torch.compile
+def compute_sum_scaled_compiled(
+    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
+) -> torch.Tensor:
+    torch.sum(x * routed_scaling_factor, dim=1, out=out)
+    return out
+
+
+def get_benchmark():
+    num_tokens_range = [2**i for i in range(0, 13)]
+
+    @triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["num_tokens"],
+            x_vals=num_tokens_range,
+            line_arg="version",
+            line_vals=["baseline", "compiled", "triton"],
+            line_names=["Original", "TorchCompile", "TritonKernel"],
+            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
+            ylabel="us",
+            plot_name="sum_scaled_performance",
+            args={},
+        )
+    )
+    def benchmark(num_tokens, version):
+        topk = 9
+        hidden_size = 4096
+        dtype = torch.bfloat16
+        scaling_factor = 0.3
+
+        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
+        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
+
+        # Warmup
+        for _ in range(3):
+            if version == "baseline":
+                compute_sum_scaled_baseline(x, out, scaling_factor)
+            elif version == "compiled":
+                compute_sum_scaled_compiled(x, out, scaling_factor)
+            else:
+                moe_sum_reduce(x, out, scaling_factor)
+
+        # Benchmark
+        quantiles = [0.5, 0.2, 0.8]
+        if version == "baseline":
+            ms, min_ms, max_ms = do_bench(
+                lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
+                quantiles=quantiles,
+            )
+        elif version == "compiled":
+            ms, min_ms, max_ms = do_bench(
+                lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
+                quantiles=quantiles,
+            )
+        else:
+            ms, min_ms, max_ms = do_bench(
+                lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
+            )
+
+        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+    return benchmark
+
+
+def verify_correctness(num_tokens=1024):
+    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
+    scaling_factor = 0.3
+
+    out_baseline = torch.empty_like(x[:, 0])
+    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)
+
+    out_compiled = torch.empty_like(out_baseline)
+    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
+
+    out_triton = torch.empty_like(out_baseline)
+    moe_sum_reduce(x, out_triton, scaling_factor)
+
+    if torch.allclose(
+        out_baseline, out_compiled, atol=1e-2, rtol=1e-2
+    ) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
+        print("✅ All implementations match")
+    else:
+        print("❌ Implementations differ")
+        print(
+            f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
+        )
+        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
+
+
+if __name__ == "__main__":
+    print("Running correctness verification...")
+    verify_correctness()
+
+    print("\nRunning performance benchmark...")
+    benchmark = get_benchmark()
+    benchmark.run(
+        print_data=True,
+        # save_path="./configs/benchmark_ops/sum_scaled/"
+    )
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 935de6e571e..5568dd6fc5e 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -1155,6 +1155,7 @@ def inplace_fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1177,6 +1178,8 @@ def inplace_fused_experts(
         a1_scale,
         a2_scale,
         block_shape,
+        False,
+        routed_scaling_factor,
     )


@@ -1200,6 +1203,7 @@ def inplace_fused_experts_fake(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     pass


@@ -1233,6 +1237,7 @@ def outplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1256,6 +1261,7 @@ def outplace_fused_experts(
         a2_scale,
         block_shape,
         no_combine=no_combine,
+        routed_scaling_factor=routed_scaling_factor,
     )


@@ -1280,6 +1286,7 @@ def outplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)


@@ -1314,7 +1321,9 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
+
     if inplace:
         assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
@@ -1337,6 +1346,7 @@ def fused_experts(
             a1_scale,
             a2_scale,
             block_shape,
+            routed_scaling_factor,
         )
         return hidden_states
     else:
@@ -1361,9 +1371,102 @@ def fused_experts(
             a2_scale,
             block_shape,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+
+# _moe_sum_reduce_kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
+@triton.jit
+def _moe_sum_reduce_kernel(
+    input_ptr,
+    input_stride_0,
+    input_stride_1,
+    input_stride_2,
+    output_ptr,
+    output_stride_0,
+    output_stride_1,
+    token_num: int,
+    topk_num: int,
+    hidden_dim: int,
+    routed_scaling_factor: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DIM: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
+    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
+    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
+
+    token_block_id = tl.program_id(0)
+    dim_block_id = tl.program_id(1)
+
+    token_start = token_block_id * BLOCK_M
+    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+
+    dim_start = dim_block_id * BLOCK_DIM
+    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+
+    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+
+    for token_index in range(token_start, token_end):
+        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
+        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
+        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+            tmp = tl.load(
+                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
+            )
+            accumulator += tmp
+        accumulator = accumulator * routed_scaling_factor
+        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
+        tl.store(
+            store_t_ptr,
+            accumulator.to(input_ptr.dtype.element_ty),
+            mask=offs_dim < dim_end,
         )


+def moe_sum_reduce_triton(
+    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
+):
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+
+    token_num, topk_num, hidden_dim = input.shape
+    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
+
+    BLOCK_M = 1
+    BLOCK_DIM = 2048
+    NUM_STAGE = 1
+    num_warps = 8
+
+    grid = (
+        triton.cdiv(token_num, BLOCK_M),
+        triton.cdiv(hidden_dim, BLOCK_DIM),
+    )
+
+    _moe_sum_reduce_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        token_num=token_num,
+        topk_num=topk_num,
+        hidden_dim=hidden_dim,
+        routed_scaling_factor=routed_scaling_factor,
+        BLOCK_M=BLOCK_M,
+        BLOCK_DIM=BLOCK_DIM,
+        NUM_STAGE=NUM_STAGE,
+        num_warps=num_warps,
+    )
+    return
+
+
+@torch.compile
+def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
+    torch.sum(x, dim=1, out=out)
+    out.mul_(routed_scaling_factor)
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -1386,6 +1489,7 @@ def fused_experts_impl(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
     if (
@@ -1562,6 +1666,9 @@ def fused_experts_impl(
             block_shape=block_shape,
         )

+        if routed_scaling_factor is None:
+            routed_scaling_factor = 1.0
+
         if no_combine:
             pass
         elif _is_hip:
@@ -1570,20 +1677,28 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
-            if topk_ids.shape[1] == 1:
+            if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
                 pass  # we write directly into out_hidden_states
-            elif topk_ids.shape[1] == 2:
+            elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
                 torch.add(
                     intermediate_cache3[:, 0],
                     intermediate_cache3[:, 1],
                     out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 ).squeeze(dim=1)
-            elif topk_ids.shape[1] > 2:
-                torch.sum(
-                    intermediate_cache3.view(*intermediate_cache3.shape),
-                    dim=1,
-                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                )
+            else:
+                # According to micro benchmark results, torch.compile gets better performance for small token counts.
+                if tokens_in_chunk <= 32:
+                    moe_sum_reduce_torch_compile(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
+                else:
+                    moe_sum_reduce_triton(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )

     return out_hidden_states

@@ -1695,4 +1810,5 @@ def fused_moe(
         a2_scale=a2_scale,
         block_shape=block_shape,
         no_combine=no_combine,
+        routed_scaling_factor=routed_scaling_factor,
     )
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 981780f421f..f1d32398239 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -225,6 +225,7 @@ def forward_cuda(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     def forward_cpu(
diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py
index c9bb5ae97eb..f3885759558 100644
--- a/python/sglang/srt/layers/quantization/blockwise_int8.py
+++ b/python/sglang/srt/layers/quantization/blockwise_int8.py
@@ -411,4 +411,5 @@ def apply(
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index abdea4e7491..0aaa3a508c8 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -317,6 +317,7 @@ def apply(
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             apply_router_weight_on_input=apply_router_weight_on_input,
+            routed_scaling_factor=routed_scaling_factor,
         )


diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index b561b26604d..c779f1f1d39 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -1030,6 +1030,7 @@ def apply(
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     def maybe_apply_hip_fused_experts(
diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py
index d7d836089d9..4be00f8a3b0 100644
--- a/python/sglang/srt/layers/quantization/moe_wna16.py
+++ b/python/sglang/srt/layers/quantization/moe_wna16.py
@@ -388,6 +388,7 @@ def apply(
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     @staticmethod
diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py
index 8255b97abc6..b2e606f4d2e 100644
--- a/python/sglang/srt/layers/quantization/w8a8_fp8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py
@@ -328,4 +328,5 @@ def apply(
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index a906e567dbb..a973403ca6a 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -268,4 +268,5 @@ def apply(
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 895ef648b5b..7970d3503e5 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -346,7 +346,7 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        final_hidden_states *= self.routed_scaling_factor
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
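Reviewer note (not part of the patch): the refactor rests on the identity that folding routed_scaling_factor into the top-k reduction equals summing first and scaling afterwards, which is exactly the line removed from deepseek_v2.py. Below is a minimal PyTorch sketch of that invariant; the shapes and the 0.3 factor mirror the benchmark above, and the helper name reference_moe_sum_scale is purely illustrative, not something introduced by this PR.

import torch


def reference_moe_sum_scale(x: torch.Tensor, routed_scaling_factor: float) -> torch.Tensor:
    # x: [num_tokens, topk, hidden_dim] -> [num_tokens, hidden_dim]
    # Scaling fused into the reduction, as the new Triton / torch.compile paths do.
    return (x * routed_scaling_factor).sum(dim=1)


if __name__ == "__main__":
    x = torch.randn(16, 9, 4096)
    fused = reference_moe_sum_scale(x, 0.3)
    # Previous behavior: reduce over the top-k experts first, then scale in the model code.
    unfused = x.sum(dim=1) * 0.3
    torch.testing.assert_close(fused, unfused, rtol=1e-5, atol=1e-5)
    print("fused scaling matches post-hoc scaling")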