From dc321f774607df75a37a3db11d2a3302acdeae9e Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Sat, 7 Jun 2025 21:02:17 -0700
Subject: [PATCH] Revert "Fuse routed scaling factor in topk_reduce kernel (#6220)"

This reverts commit 515ef4facbc89cd7c093c198386a8817fce856d6.
---
 .../fused_moe_triton/benchmark_sum_scale.py   | 199 ------------------
 .../layers/moe/fused_moe_triton/fused_moe.py  | 132 +-----------
 .../srt/layers/moe/fused_moe_triton/layer.py  |   1 -
 .../srt/layers/quantization/blockwise_int8.py |   1 -
 .../compressed_tensors_moe.py                 |   1 -
 python/sglang/srt/layers/quantization/fp8.py  |   1 -
 .../srt/layers/quantization/moe_wna16.py      |   1 -
 .../srt/layers/quantization/w8a8_fp8.py       |   1 -
 .../srt/layers/quantization/w8a8_int8.py      |   1 -
 python/sglang/srt/models/deepseek_v2.py       |   2 +-
 10 files changed, 9 insertions(+), 331 deletions(-)
 delete mode 100644 benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py

diff --git a/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py b/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
deleted file mode 100644
index 13ff617448e..00000000000
--- a/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py
+++ /dev/null
@@ -1,199 +0,0 @@
-import torch
-import triton
-import triton.language as tl
-from triton.testing import do_bench
-
-
-# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
-@triton.jit
-def _moe_sum_reduce_kernel(
-    input_ptr,
-    input_stride_0,
-    input_stride_1,
-    input_stride_2,
-    output_ptr,
-    output_stride_0,
-    output_stride_1,
-    token_num: int,
-    topk_num: int,
-    hidden_dim: int,
-    routed_scaling_factor: tl.constexpr,
-    BLOCK_M: tl.constexpr,
-    BLOCK_DIM: tl.constexpr,
-    NUM_STAGE: tl.constexpr,
-):
-    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
-    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
-    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
-
-    token_block_id = tl.program_id(0)
-    dim_block_id = tl.program_id(1)
-
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
-
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
-
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
-
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
-        )
-
-
-def moe_sum_reduce(
-    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
-):
-    assert input.is_contiguous()
-    assert output.is_contiguous()
-
-    token_num, topk_num, hidden_dim = input.shape
-    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
-
-    BLOCK_M = 1
-    BLOCK_DIM = 2048
-    NUM_STAGE = 1
-    num_warps = 8
-
-    grid = (
-        triton.cdiv(token_num, BLOCK_M),
-        triton.cdiv(hidden_dim, BLOCK_DIM),
-    )
-
-    _moe_sum_reduce_kernel[grid](
-        input,
-        *input.stride(),
-        output,
-        *output.stride(),
-        token_num=token_num,
-        topk_num=topk_num,
-        hidden_dim=hidden_dim,
-        routed_scaling_factor=routed_scaling_factor,
-        BLOCK_M=BLOCK_M,
-        BLOCK_DIM=BLOCK_DIM,
-        NUM_STAGE=NUM_STAGE,
-        num_warps=num_warps,
-    )
-    return
-
-
-def compute_sum_scaled_baseline(
-    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
-) -> torch.Tensor:
-    torch.sum(x, dim=1, out=out)
-    out.mul_(routed_scaling_factor)
-    return out
-
-
-@torch.compile
-def compute_sum_scaled_compiled(
-    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
-) -> torch.Tensor:
-    torch.sum(x * routed_scaling_factor, dim=1, out=out)
-    return out
-
-
-def get_benchmark():
-    num_tokens_range = [2**i for i in range(0, 13)]
-
-    @triton.testing.perf_report(
-        triton.testing.Benchmark(
-            x_names=["num_tokens"],
-            x_vals=num_tokens_range,
-            line_arg="version",
-            line_vals=["baseline", "compiled", "triton"],
-            line_names=["Original", "TorchCompile", "TritonKernel"],
-            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
-            ylabel="us",
-            plot_name="sum_scaled_performance",
-            args={},
-        )
-    )
-    def benchmark(num_tokens, version):
-        topk = 9
-        hidden_size = 4096
-        dtype = torch.bfloat16
-        scaling_factor = 0.3
-
-        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
-        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
-
-        # Warmup
-        for _ in range(3):
-            if version == "baseline":
-                compute_sum_scaled_baseline(x, out, scaling_factor)
-            elif version == "compiled":
-                compute_sum_scaled_compiled(x, out, scaling_factor)
-            else:
-                moe_sum_reduce(x, out, scaling_factor)
-
-        # Benchmark
-        quantiles = [0.5, 0.2, 0.8]
-        if version == "baseline":
-            ms, min_ms, max_ms = do_bench(
-                lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
-                quantiles=quantiles,
-            )
-        elif version == "compiled":
-            ms, min_ms, max_ms = do_bench(
-                lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
-                quantiles=quantiles,
-            )
-        else:
-            ms, min_ms, max_ms = do_bench(
-                lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
-            )
-
-        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
-
-    return benchmark
-
-
-def verify_correctness(num_tokens=1024):
-    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
-    scaling_factor = 0.3
-
-    out_baseline = torch.empty_like(x[:, 0])
-    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)
-
-    out_compiled = torch.empty_like(out_baseline)
-    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
-
-    out_triton = torch.empty_like(out_baseline)
-    moe_sum_reduce(x, out_triton, scaling_factor)
-
-    if torch.allclose(
-        out_baseline, out_compiled, atol=1e-2, rtol=1e-2
-    ) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
-        print("✅ All implementations match")
-    else:
-        print("❌ Implementations differ")
-        print(
-            f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
-        )
-        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
-
-
-if __name__ == "__main__":
-    print("Running correctness verification...")
-    verify_correctness()
-
-    print("\nRunning performance benchmark...")
-    benchmark = get_benchmark()
-    benchmark.run(
-        print_data=True,
-        # save_path="./configs/benchmark_ops/sum_scaled/"
-    )
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 5568dd6fc5e..935de6e571e 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -1155,7 +1155,6 @@ def inplace_fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1178,8 +1177,6 @@ def inplace_fused_experts(
         a1_scale,
         a2_scale,
         block_shape,
-        False,
-        routed_scaling_factor,
     )
 
 
@@ -1203,7 +1200,6 @@ def inplace_fused_experts_fake(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     pass
 
 
@@ -1237,7 +1233,6 @@ def outplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1261,7 +1256,6 @@ def outplace_fused_experts(
         a2_scale,
         block_shape,
         no_combine=no_combine,
-        routed_scaling_factor=routed_scaling_factor,
     )
 
 
@@ -1286,7 +1280,6 @@ def outplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
 
@@ -1321,9 +1314,7 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
 ):
-
     if inplace:
         assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
@@ -1346,7 +1337,6 @@ def fused_experts(
             a1_scale,
             a2_scale,
             block_shape,
-            routed_scaling_factor,
         )
         return hidden_states
     else:
@@ -1371,102 +1361,9 @@ def fused_experts(
             a2_scale,
             block_shape,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-        )
-
-
-# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
-@triton.jit
-def _moe_sum_reduce_kernel(
-    input_ptr,
-    input_stride_0,
-    input_stride_1,
-    input_stride_2,
-    output_ptr,
-    output_stride_0,
-    output_stride_1,
-    token_num: int,
-    topk_num: int,
-    hidden_dim: int,
-    routed_scaling_factor: tl.constexpr,
-    BLOCK_M: tl.constexpr,
-    BLOCK_DIM: tl.constexpr,
-    NUM_STAGE: tl.constexpr,
-):
-    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
-    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
-    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
-
-    token_block_id = tl.program_id(0)
-    dim_block_id = tl.program_id(1)
-
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
-
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
-
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
-
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
         )
 
 
-def moe_sum_reduce_triton(
-    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
-):
-    assert input.is_contiguous()
-    assert output.is_contiguous()
-
-    token_num, topk_num, hidden_dim = input.shape
-    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
-
-    BLOCK_M = 1
-    BLOCK_DIM = 2048
-    NUM_STAGE = 1
-    num_warps = 8
-
-    grid = (
-        triton.cdiv(token_num, BLOCK_M),
-        triton.cdiv(hidden_dim, BLOCK_DIM),
-    )
-
-    _moe_sum_reduce_kernel[grid](
-        input,
-        *input.stride(),
-        output,
-        *output.stride(),
-        token_num=token_num,
-        topk_num=topk_num,
-        hidden_dim=hidden_dim,
-        routed_scaling_factor=routed_scaling_factor,
-        BLOCK_M=BLOCK_M,
-        BLOCK_DIM=BLOCK_DIM,
-        NUM_STAGE=NUM_STAGE,
-        num_warps=num_warps,
-    )
-    return
-
-
-@torch.compile
-def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
-    torch.sum(x, dim=1, out=out)
-    out.mul_(routed_scaling_factor)
-
-
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
@@ -1489,7 +1386,6 @@ def fused_experts_impl(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
     if (
@@ -1666,9 +1562,6 @@ def fused_experts_impl(
                 block_shape=block_shape,
             )
 
-        if routed_scaling_factor is None:
-            routed_scaling_factor = 1.0
-
         if no_combine:
             pass
         elif _is_hip:
@@ -1677,28 +1570,20 @@
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
-            if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
+            if topk_ids.shape[1] == 1:
                 pass  # we write directly into out_hidden_states
-            elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
+            elif topk_ids.shape[1] == 2:
                 torch.add(
                     intermediate_cache3[:, 0],
                     intermediate_cache3[:, 1],
                     out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 ).squeeze(dim=1)
-            else:
-                # According to micro benchmark results, torch.compile can get better performance for small token.
-                if tokens_in_chunk <= 32:
-                    moe_sum_reduce_torch_compile(
-                        intermediate_cache3.view(*intermediate_cache3.shape),
-                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                        routed_scaling_factor,
-                    )
-                else:
-                    moe_sum_reduce_triton(
-                        intermediate_cache3.view(*intermediate_cache3.shape),
-                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                        routed_scaling_factor,
-                    )
+            elif topk_ids.shape[1] > 2:
+                torch.sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    dim=1,
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
 
     return out_hidden_states
 
@@ -1810,5 +1695,4 @@ def fused_moe(
         a2_scale=a2_scale,
         block_shape=block_shape,
         no_combine=no_combine,
-        routed_scaling_factor=routed_scaling_factor,
     )
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 1804963eb67..adc46a82e16 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -225,7 +225,6 @@ def forward_cuda(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     def forward_cpu(
diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py
index f3885759558..c9bb5ae97eb 100644
--- a/python/sglang/srt/layers/quantization/blockwise_int8.py
+++ b/python/sglang/srt/layers/quantization/blockwise_int8.py
@@ -411,5 +411,4 @@ def apply(
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 0aaa3a508c8..abdea4e7491 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -317,7 +317,6 @@ def apply(
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
 
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index c779f1f1d39..b561b26604d 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -1030,7 +1030,6 @@ def apply(
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     def maybe_apply_hip_fused_experts(
diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py
index 4be00f8a3b0..d7d836089d9 100644
--- a/python/sglang/srt/layers/quantization/moe_wna16.py
+++ b/python/sglang/srt/layers/quantization/moe_wna16.py
@@ -388,7 +388,6 @@ def apply(
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     @staticmethod
diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py
index b2e606f4d2e..8255b97abc6 100644
--- a/python/sglang/srt/layers/quantization/w8a8_fp8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py
@@ -328,5 +328,4 @@ def apply(
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index a973403ca6a..a906e567dbb 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -268,5 +268,4 @@ def apply(
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 9a53a1c77ab..83211e8ebd8 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -346,7 +346,7 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-
+        final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
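Note (not part of the patch): the revert only moves where the routed scaling factor is multiplied in; the numerical result is unchanged. Before the revert, the scale was fused into the top-k reduction (the removed _moe_sum_reduce_kernel / moe_sum_reduce_torch_compile paths); after the revert, fused_experts_impl does a plain torch.sum over the top-k dimension and the model applies the scale afterwards (see the deepseek_v2.py hunk). A minimal plain-PyTorch sketch with illustrative shapes, not taken from the repository:

    import torch

    def fused_sum_scale(x: torch.Tensor, routed_scaling_factor: float) -> torch.Tensor:
        # Reverted behavior: the scale is folded into the top-k reduction.
        return torch.sum(x, dim=1) * routed_scaling_factor

    def unfused_sum_scale(x: torch.Tensor, routed_scaling_factor: float) -> torch.Tensor:
        # Restored behavior: plain sum over the top-k dimension inside the MoE
        # combine step, with the scale applied later in the model forward pass.
        out = torch.sum(x, dim=1)
        return out * routed_scaling_factor

    # Illustrative shape: (num_tokens, topk, hidden_dim)
    x = torch.randn(4, 9, 4096, dtype=torch.float32)
    assert torch.allclose(fused_sum_scale(x, 0.3), unfused_sum_scale(x, 0.3))

The trade-off the revert gives up is purely a kernel-launch/bandwidth optimization for the combine step, not a change in model output.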