diff --git a/python/sglang/jit_kernel/activation.py b/python/sglang/jit_kernel/activation.py
index 5756c2f38287..89b28bbf6d92 100644
--- a/python/sglang/jit_kernel/activation.py
+++ b/python/sglang/jit_kernel/activation.py
@@ -38,6 +38,10 @@ def _jit_activation_module(dtype: torch.dtype) -> Module:
         extra_cuda_cflags=_fast_math_flags(),
         cuda_wrappers=[
             ("run_activation", f"ActivationKernel<{args}>::run_activation"),
+            (
+                "run_activation_filtered",
+                f"ActivationKernel<{args}>::run_activation_filtered",
+            ),
         ],
     )
 
@@ -56,30 +60,68 @@ def _run_activation_inplace(
     module.run_activation(input_2d, out_2d, op_name)
 
 
+@register_custom_op(mutates_args=["out"])
+def _run_activation_filtered_inplace(
+    op_name: str,
+    input: torch.Tensor,
+    out: torch.Tensor,
+    expert_ids: torch.Tensor,
+    expert_step: int,
+) -> None:
+    hidden_size = input.shape[-1] // 2
+    module = _jit_activation_module(input.dtype)
+    input_2d = input.view(-1, hidden_size * 2)
+    out_2d = out.view(-1, hidden_size)
+    module.run_activation_filtered(input_2d, out_2d, expert_ids, expert_step, op_name)
+
+
 def run_activation(
-    op_name: str, input: torch.Tensor, out: Optional[torch.Tensor]
+    op_name: str,
+    input: torch.Tensor,
+    out: Optional[torch.Tensor],
+    expert_ids: Optional[torch.Tensor] = None,
+    expert_step: int = 1,
 ) -> torch.Tensor:
+    """Apply ``op_name`` activation followed by element-wise multiplication.
+
+    When ``expert_ids`` is provided, output rows are skipped for tokens whose
+    routed expert id is ``-1``. ``expert_step`` is 1 for per-token routing and
+    ``BLOCK_SIZE_M`` for sorted/TMA routing — i.e. ``expert_ids[token_id //
+    expert_step]`` is consulted before computing each row.
+    """
     assert op_name in SUPPORTED_ACTIVATIONS, f"Unsupported activation: {op_name}"
     hidden_size = input.shape[-1] // 2
     if out is None:
         out = input.new_empty(*input.shape[:-1], hidden_size)
-    _run_activation_inplace(op_name, input, out)
+    if expert_ids is None:
+        _run_activation_inplace(op_name, input, out)
+    else:
+        _run_activation_filtered_inplace(op_name, input, out, expert_ids, expert_step)
     return out
 
 
 def silu_and_mul(
-    input: torch.Tensor, out: Optional[torch.Tensor] = None
+    input: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    expert_ids: Optional[torch.Tensor] = None,
+    expert_step: int = 1,
 ) -> torch.Tensor:
-    return run_activation("silu", input, out)
+    return run_activation("silu", input, out, expert_ids, expert_step)
 
 
 def gelu_and_mul(
-    input: torch.Tensor, out: Optional[torch.Tensor] = None
+    input: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    expert_ids: Optional[torch.Tensor] = None,
+    expert_step: int = 1,
 ) -> torch.Tensor:
-    return run_activation("gelu", input, out)
+    return run_activation("gelu", input, out, expert_ids, expert_step)
 
 
 def gelu_tanh_and_mul(
-    input: torch.Tensor, out: Optional[torch.Tensor] = None
+    input: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    expert_ids: Optional[torch.Tensor] = None,
+    expert_step: int = 1,
 ) -> torch.Tensor:
-    return run_activation("gelu_tanh", input, out)
+    return run_activation("gelu_tanh", input, out, expert_ids, expert_step)
diff --git a/python/sglang/jit_kernel/benchmark/bench_activation.py b/python/sglang/jit_kernel/benchmark/bench_activation.py
index 2caac5551c8e..3f0ba2f6c85e 100644
--- a/python/sglang/jit_kernel/benchmark/bench_activation.py
+++ b/python/sglang/jit_kernel/benchmark/bench_activation.py
@@ -82,5 +82,76 @@ def f():
     return run_benchmark(f, scale=NUM_LAYERS)
 
 
+FILTER_OPS = ["silu", "gelu"]
+FILTER_BS = get_benchmark_range(
+    full_range=[64, 256, 1024, 4096, 16384], ci_range=[1024]
+)
+FILTER_DIMS = get_benchmark_range(full_range=[1024, 4096, 8192], ci_range=[4096])
+FILTER_RATIOS = get_benchmark_range(full_range=[0.0, 0.25, 0.5], ci_range=[0.25])
+FILTER_CONFIGS = list(
+    itertools.product(FILTER_OPS, FILTER_DIMS, FILTER_BS, FILTER_RATIOS)
+)
+
+
+def _make_expert_ids(num_tokens: int, skip_ratio: float) -> torch.Tensor:
+    expert_ids = torch.randint(
+        low=0, high=8, size=(num_tokens,), dtype=torch.int32, device=DEFAULT_DEVICE
+    )
+    if skip_ratio > 0:
+        skip = torch.rand(num_tokens, device=DEFAULT_DEVICE) < skip_ratio
+        expert_ids[skip] = -1
+    return expert_ids
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["op_name", "dim", "batch_size", "skip_ratio"],
+        x_vals=FILTER_CONFIGS,
+        line_arg="provider",
+        line_vals=["unfiltered", "filtered"],
+        line_names=["JIT (no filter_expert)", "JIT (with expert_ids)"],
+        styles=[("blue", "--"), ("orange", "-")],
+        ylabel="us",
+        plot_name="activation-filter-expert",
+        args={},
+    )
+)
+def benchmark_filter(
+    op_name: str, dim: int, batch_size: int, skip_ratio: float, provider: str
+):
+    x = torch.randn(
+        NUM_LAYERS,
+        batch_size,
+        2 * dim,
+        dtype=DEFAULT_DTYPE,
+        device=DEFAULT_DEVICE,
+    )
+    out = torch.empty(
+        NUM_LAYERS,
+        batch_size,
+        dim,
+        dtype=DEFAULT_DTYPE,
+        device=DEFAULT_DEVICE,
+    )
+    expert_ids = _make_expert_ids(batch_size, skip_ratio)
+
+    jit_fn = silu_and_mul_jit if op_name == "silu" else gelu_and_mul_jit
+
+    if provider == "unfiltered":
+
+        def f():
+            for i in range(NUM_LAYERS):
+                jit_fn(x[i], out[i])
+
+    else:  # filtered
+
+        def f():
+            for i in range(NUM_LAYERS):
+                jit_fn(x[i], out[i], expert_ids=expert_ids, expert_step=1)
+
+    return run_benchmark(f, scale=NUM_LAYERS)
+
+
 if __name__ == "__main__":
     benchmark.run(print_data=True)
+    benchmark_filter.run(print_data=True)
diff --git a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh
index 759703b85163..1396eb5a992e 100644
--- a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh
+++ b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh
@@ -44,9 +44,14 @@ struct ActivationParams {
   void* __restrict__ out;
   uint32_t hidden_dim;
   uint32_t num_tokens;
+  // Optional MoE expert filtering: when expert_ids != nullptr, a token is
+  // skipped if expert_ids[token_id / expert_step] == -1. expert_step is 1
+  // for per-token routing and BLOCK_SIZE_M for sorted/TMA routing.
+  const int32_t* __restrict__ expert_ids;
+  uint32_t expert_step;
 };
 
-template
+template
 __global__ void act_and_mul_kernel(const __grid_constant__ ActivationParams params) {
   using namespace device;
   constexpr auto kVecSize = kMaxVecBytes / sizeof(T);
@@ -56,6 +61,9 @@
 
   const auto token_id = tid / num_vecs;
   if (token_id >= params.num_tokens) return;
+  if constexpr (kFilterExpert) {
+    if (params.expert_ids[token_id / params.expert_step] == -1) return;
+  }
   const auto offset = tid % num_vecs;
   const auto input_offset = token_id * (num_vecs * 2) + offset;
   const auto output_offset = tid;
@@ -78,11 +86,33 @@ struct ActivationKernel {
   static constexpr auto kVecSize = device::kMaxVecBytes / sizeof(T);
   static constexpr auto kBlockSize = 256u;
 
-  template
-  static constexpr auto activation_kernel = act_and_mul_kernel;
+  template
+  static constexpr auto activation_kernel = act_and_mul_kernel;
   static_assert(device::kMaxVecBytes % sizeof(T) == 0, "unsupported data type");
 
-  static void run_activation(const tvm::ffi::TensorView input, const tvm::ffi::TensorView out, std::string type) {
+
+  template <bool kFilterExpert>
+  static auto select_kernel(const std::string& type)
+      -> decltype(activation_kernel) {
+    using namespace host;
+    if (type == "silu") {
+      return activation_kernel;
+    } else if (type == "gelu") {
+      return activation_kernel;
+    } else if (type == "gelu_tanh") {
+      return activation_kernel;
+    } else {
+      Panic("unsupported activation type: ", type);
+    }
+    return nullptr;
+  }
+
+  static void launch(
+      const tvm::ffi::TensorView& input,
+      const tvm::ffi::TensorView& out,
+      const std::string& type,
+      const int32_t* expert_ids,
+      uint32_t expert_step) {
     using namespace host;
     auto N = SymbolicSize{"num_tokens"};
@@ -106,18 +136,6 @@ struct ActivationKernel {
     if (num_tokens == 0) return;
     RuntimeCheck(hidden_size * 2 == D_in.unwrap(), "invalid activation dimension");
     RuntimeCheck(hidden_size % kVecSize == 0, "hidden size must be divisible by vector size");
-    const auto kernel = [&]() -> decltype(activation_kernel) {
-      if (type == "silu") {
-        return activation_kernel;
-      } else if (type == "gelu") {
-        return activation_kernel;
-      } else if (type == "gelu_tanh") {
-        return activation_kernel;
-      } else {
-        Panic("unsupported activation type: ", type);
-      }
-      return nullptr;
-    }(); // only get once to avoid overhead
     const auto num_total_items = num_tokens * (hidden_size / kVecSize);
     RuntimeCheck(num_total_items <= std::numeric_limits::max(), "too many items for 32-bit indexing");
@@ -127,8 +145,33 @@ struct ActivationKernel {
         .out = out.data_ptr(),
         .hidden_dim = hidden_size,
         .num_tokens = num_tokens,
+        .expert_ids = expert_ids,
+        .expert_step = expert_step,
     };
-    LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params);
+    if (expert_ids != nullptr) {
+      RuntimeCheck(expert_step > 0, "expert_step must be positive");
+      const auto kernel = select_kernel<true>(type);
+      LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params);
+    } else {
+      const auto kernel = select_kernel<false>(type);
+      LaunchKernel(num_blocks, kBlockSize, device).enable_pdl(kUsePDL)(kernel, params);
+    }
+  }
+
+  static void run_activation(const tvm::ffi::TensorView input, const tvm::ffi::TensorView out, std::string type) {
+    launch(input, out, type, /*expert_ids=*/nullptr, /*expert_step=*/1);
+  }
+
+  static void run_activation_filtered(
+      const tvm::ffi::TensorView input,
+      const tvm::ffi::TensorView out,
+      const tvm::ffi::TensorView expert_ids,
+      int64_t expert_step,
+      std::string type) {
+    using namespace host;
+    RuntimeCheck(is_type<int32_t>(expert_ids.dtype()), "expert_ids must have dtype int32");
+    RuntimeCheck(expert_step >= 1, "expert_step must be positive");
+    launch(input, out, type, static_cast<const int32_t*>(expert_ids.data_ptr()), static_cast<uint32_t>(expert_step));
+  }
 };
diff --git a/python/sglang/jit_kernel/tests/test_activation.py b/python/sglang/jit_kernel/tests/test_activation.py
index c43bc41bfcd7..4ad39b269f14 100644
--- a/python/sglang/jit_kernel/tests/test_activation.py
+++ b/python/sglang/jit_kernel/tests/test_activation.py
@@ -75,5 +75,85 @@ def test_activation_out_param(
     torch.testing.assert_close(out, expected, atol=atol, rtol=rtol)
 
 
+FILTER_SHAPES = get_ci_test_range(
+    full_range=[(83, 1024), (256, 2048), (1024, 4096)],
+    ci_range=[(83, 1024)],
+)
+EXPERT_STEPS = [1, 16]
+
+
+@pytest.mark.parametrize("op_name", OPS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", FILTER_SHAPES)
+@pytest.mark.parametrize("expert_step", EXPERT_STEPS)
+def test_activation_filter_expert(
+    op_name: str,
+    dtype: torch.dtype,
+    shape: tuple[int, int],
+    expert_step: int,
+) -> None:
+    """expert_ids[token // expert_step] == -1 must leave the output row untouched."""
+    num_tokens = shape[0]
+    x = torch.randn(shape, dtype=dtype, device="cuda")
+    # Pre-fill out with a sentinel so we can detect untouched rows.
+    sentinel = float("nan")
+    out = torch.full(
+        shape[:-1] + (shape[-1] // 2,),
+        sentinel,
+        dtype=dtype,
+        device="cuda",
+    )
+
+    num_groups = (num_tokens + expert_step - 1) // expert_step
+    expert_ids = torch.randint(
+        low=0, high=8, size=(num_groups,), dtype=torch.int32, device="cuda"
+    )
+    skip_mask = torch.rand(num_groups, device="cuda") < 0.4
+    expert_ids[skip_mask] = -1
+
+    result = run_activation(op_name, x, out, expert_ids, expert_step)
+    assert result is out
+
+    token_skip = skip_mask[torch.arange(num_tokens, device="cuda") // expert_step]
+    expected = _reference(op_name, x)
+    atol, rtol = _tolerances(dtype)
+
+    kept = ~token_skip
+    if kept.any():
+        torch.testing.assert_close(out[kept], expected[kept], atol=atol, rtol=rtol)
+    if token_skip.any():
+        assert torch.isnan(
+            out[token_skip]
+        ).all(), "filter_expert kernel touched rows whose expert_id is -1"
+
+
+@pytest.mark.parametrize("op_name", OPS)
+def test_activation_filter_expert_all_skipped(op_name: str) -> None:
+    """If every expert id is -1, the output must be left entirely untouched."""
+    shape = (32, 512)
+    x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
+    out = torch.full(
+        shape[:-1] + (shape[-1] // 2,),
+        float("nan"),
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    expert_ids = torch.full((shape[0],), -1, dtype=torch.int32, device="cuda")
+    run_activation(op_name, x, out, expert_ids, 1)
+    assert torch.isnan(out).all()
+
+
+@pytest.mark.parametrize("op_name", OPS)
+def test_activation_filter_expert_none_skipped(op_name: str) -> None:
+    """No -1 in expert_ids must yield bit-identical output to the unfiltered path."""
+    shape = (64, 512)
+    dtype = torch.bfloat16
+    x = torch.randn(shape, dtype=dtype, device="cuda")
+    expert_ids = torch.zeros((shape[0],), dtype=torch.int32, device="cuda")
+    out_filtered = run_activation(op_name, x, None, expert_ids, 1)
+    out_unfiltered = run_activation(op_name, x, None)
+    torch.testing.assert_close(out_filtered, out_unfiltered, atol=0.0, rtol=0.0)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__, "-v", "-s"]))
diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py
index 8aa577c97662..fff837fe49cc 100644
--- a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py
+++ b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py
@@ -28,7 +28,6 @@
 from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config
 from .fused_moe_triton_kernels import (
-    act_and_mul_triton,
     invoke_fused_moe_kernel,
     moe_sum_reduce_triton,
     support_tensor_descriptor,
@@ -530,18 +529,17 @@ def _fused_moe_kernel_sequence(
                 intermediate_cache1.view(-1, N), gemm1_limit
             )
         elif _is_cuda or _is_hip or _is_xpu:
-            if not filter_expert:
-                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
-            else:
-                act_and_mul_triton(
+            if filter_expert and _is_cuda:
+                # HIP/XPU fall through to the unfiltered path: the down kernel
+                # zeros filtered rows without reading their input.
+                silu_and_mul(
                     intermediate_cache1.view(-1, N),
                     intermediate_cache2,
-                    config,
-                    topk_ids,
-                    expert_ids,
-                    down_moe_use_tma,
-                    activation,
+                    expert_ids=(expert_ids if down_moe_use_tma else topk_ids.view(-1)),
+                    expert_step=(config["BLOCK_SIZE_M"] if down_moe_use_tma else 1),
                 )
+            else:
+                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
         elif _is_musa:
             intermediate_cache2 = _silu_and_mul_musa(intermediate_cache1.view(-1, N))
         else:
@@ -558,18 +556,15 @@ def _fused_moe_kernel_sequence(
         assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
         assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
         if _is_cuda or _is_hip:
-            if not filter_expert:
-                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
-            else:
-                act_and_mul_triton(
+            if filter_expert and _is_cuda:
+                gelu_and_mul(
                     intermediate_cache1.view(-1, N),
                     intermediate_cache2,
-                    config,
-                    topk_ids,
-                    expert_ids,
-                    down_moe_use_tma,
-                    activation,
+                    expert_ids=(expert_ids if down_moe_use_tma else topk_ids.view(-1)),
+                    expert_step=(config["BLOCK_SIZE_M"] if down_moe_use_tma else 1),
                 )
+            else:
+                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
         else:
             if _has_vllm_ops:
                 vllm_ops.gelu_and_mul(
diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_kernels.py b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_kernels.py
index 323d263bf429..d02a29762ea9 100644
--- a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_kernels.py
+++ b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_kernels.py
@@ -930,112 +930,6 @@ def invoke_fused_moe_kernel(
     )
 
 
-@triton.jit
-def tanh(x):
-    return 2 * tl.sigmoid(2 * x) - 1
-
-
-@triton.jit
-def _apply_activation(x, ACTIVATION_TYPE: tl.constexpr):
-    """
-    Apply activation function based on compile-time constant.
-
-    Args:
-        x: Input tensor (converted to float32 inside)
-        ACTIVATION_TYPE: Compile-time constant string ("silu" or "gelu")
-
-    Returns:
-        Activated output in the same dtype as input
-    """
-    x = x.to(tl.float32)
-    if ACTIVATION_TYPE == "silu":
-        return x * tl.sigmoid(x)
-    elif ACTIVATION_TYPE == "gelu":
-        kAlpha = 0.7978845608028654
-        return 0.5 * x * (1 + tanh(kAlpha * (x + 0.044715 * x * x * x)))
-    else:
-        raise ValueError(f"Unsupported activation: {ACTIVATION_TYPE}")
-
-
-@triton.jit
-def act_and_mul_kernel(
-    gateup_output,
-    down_input,
-    hidden_size,
-    expert_ids_ptr,
-    expert_step: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    ACTIVATION_TYPE: tl.constexpr,
-):
-    """
-    Unified activation and multiply kernel that handles both sorted and unsorted routing,
-    and both SiLU and GELU activations using compile-time constants.
-    """
-    InDtype = gateup_output.dtype.element_ty
-    OutDtype = down_input.dtype.element_ty
-
-    half_hidden_size = hidden_size // 2
-    pid = tl.program_id(0)
-
-    expert_id = tl.load(expert_ids_ptr + pid // expert_step)
-
-    if expert_id == -1:
-        return
-
-    gateup_output_ptr = gateup_output + pid * hidden_size
-    down_input_ptr = down_input + pid * half_hidden_size
-    gate_output_ptr = gateup_output_ptr
-    up_output_ptr = gateup_output_ptr + half_hidden_size
-
-    for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
-        offset = start_offset + tl.arange(0, BLOCK_SIZE)
-        mask = offset < half_hidden_size
-
-        gate_output = tl.load(gate_output_ptr + offset, mask=mask)
-        up_output = tl.load(up_output_ptr + offset, mask=mask)
-
-        gate_output_activated = _apply_activation(gate_output, ACTIVATION_TYPE)
-        gate_output_activated = gate_output_activated.to(InDtype)
-
-        act_mul_output = gate_output_activated * up_output
-        act_mul_output = act_mul_output.to(OutDtype)
-        tl.store(down_input_ptr + offset, act_mul_output, mask=mask)
-
-
-def act_and_mul_triton(
-    gateup_output: torch.Tensor,
-    down_input: torch.Tensor,
-    config: Dict[str, Any],
-    topk_ids: Optional[torch.Tensor] = None,
-    expert_ids: Optional[torch.Tensor] = None,
-    down_moe_use_tma: bool = False,
-    activation: str = "silu",
-) -> None:
-    """
-    Args:
-        gateup_output: Input tensor containing gate and up outputs concatenated
-        down_input: Output tensor for the result
-        config: Configuration dictionary with BLOCK_SIZE_M and BLOCK_SIZE_N
-        topk_ids: Expert IDs for unsorted routing (used when down_moe_use_tma=False)
-        expert_ids: Expert IDs for sorted routing (used when down_moe_use_tma=True)
-        down_moe_use_tma: Whether to use sorted routing layout
-        activation: Activation type ("silu" or "gelu")
-    """
-    grid = (down_input.shape[0],)
-    hidden_size = gateup_output.shape[1]
-    expert_ids_row = topk_ids.view(-1) if not down_moe_use_tma else expert_ids
-    expert_step = 1 if not down_moe_use_tma else config["BLOCK_SIZE_M"]
-    act_and_mul_kernel[grid](
-        gateup_output,
-        down_input,
-        hidden_size,
-        expert_ids_row,
-        expert_step,
-        BLOCK_SIZE=512,
-        ACTIVATION_TYPE=activation,
-    )
-
-
 # _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
 @triton.jit
 def _moe_sum_reduce_kernel(
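Usage sketch for the new filtered activation path (a minimal illustration, not part of the patch). It assumes a CUDA device and imports silu_and_mul from sglang.jit_kernel.activation, the module touched in the first file of the diff; the shapes and the int32 -1 sentinel follow the contract documented in run_activation's docstring.

import torch

from sglang.jit_kernel.activation import silu_and_mul

# Packed gate/up projections: input last dim is 2 * d, output last dim is d.
num_tokens, d = 8, 1024
x = torch.randn(num_tokens, 2 * d, dtype=torch.bfloat16, device="cuda")
out = torch.empty(num_tokens, d, dtype=torch.bfloat16, device="cuda")

# Per-token routing (expert_step=1): rows whose expert id is -1 are left
# untouched; every other row receives silu(gate) * up.
expert_ids = torch.tensor([0, 1, -1, 2, -1, 3, 0, 1], dtype=torch.int32, device="cuda")
silu_and_mul(x, out, expert_ids=expert_ids, expert_step=1)

# For sorted/TMA routing the caller would instead pass block-level expert ids and
# expert_step=BLOCK_SIZE_M, so expert_ids[token_id // BLOCK_SIZE_M] is consulted.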