[MoE] Deprecate act_and_mul_triton; fold filter_expert into JIT silu/gelu_and_mul#23707
Conversation
…gelu_and_mul

Replace the Triton `act_and_mul_triton` kernel with an extension to the JIT CUDA `silu_and_mul` / `gelu_and_mul` kernels: they now accept optional `expert_ids` / `expert_step` kwargs and skip rows whose routed expert is -1. CUDA filter_expert paths in `fused_moe.py` route through the new JIT path.

HIP keeps using the AOT `sgl_kernel` `silu_and_mul` / `gelu_and_mul` for both filtered and unfiltered cases — the downstream fused_moe down kernel writes zeros for filtered experts before reading their input rows (`fused_moe_triton_kernels.py:192-208`), so writing a real activation to those rows is harmless. This avoids exercising the JIT activation kernel on AMD for the first time in this PR.

Tests: extend `test_activation.py` with filter_expert coverage (per-token and sorted/TMA layouts, all-skipped, none-skipped); 348 tests pass.

Benchmark: `bench_activation.py` adds an unfiltered-vs-filtered comparison that confirms the `expert_ids` skip path costs <0.3μs and scales work linearly with skip ratio.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Code Review
This pull request introduces expert-based filtering to the JIT activation kernels, allowing computation to be skipped for tokens based on expert IDs. The implementation replaces the previous Triton-based filtered activation with a unified CUDA kernel and updates the MoE runner to utilize this new path. Feedback was provided to include safety checks for the expert_ids tensor, specifically verifying its device compatibility and dimensionality to prevent potential runtime crashes or illegal memory access.
using namespace host;
RuntimeCheck(is_type<int32_t>(expert_ids.dtype()), "expert_ids must have dtype int32");
RuntimeCheck(expert_step >= 1, "expert_step must be positive");
launch(input, out, type, static_cast<const int32_t*>(expert_ids.data_ptr()), static_cast<uint32_t>(expert_step));
The expert_ids tensor should be verified to be on the same device as the input and out tensors. Accessing a CPU tensor's data pointer from a CUDA kernel will lead to a segmentation fault or illegal memory access. Additionally, verifying that expert_ids is a 1D tensor ensures the indexing logic in the kernel remains valid.
using namespace host;
RuntimeCheck(is_type<int32_t>(expert_ids.dtype()), "expert_ids must have dtype int32");
RuntimeCheck(expert_ids.device().device_type == input.device().device_type &&
expert_ids.device().device_id == input.device().device_id,
"expert_ids must be on the same device as input");
RuntimeCheck(expert_ids.ndim() == 1, "expert_ids must be a 1D tensor");
RuntimeCheck(expert_step >= 1, "expert_step must be positive");
launch(input, out, type, static_cast<const int32_t*>(expert_ids.data_ptr()), static_cast<uint32_t>(expert_step));
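On the Python side, the same safety conditions can be expressed before dispatching to the kernel. Below is a minimal sketch over a lightweight metadata stand-in; the `TensorMeta` type and `check_expert_ids` helper are hypothetical illustrations of the reviewer's suggested checks, not code from this PR:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorMeta:
    """Minimal stand-in for the tensor attributes the RuntimeChecks inspect."""
    dtype: str
    device_type: str
    device_id: int
    ndim: int


def check_expert_ids(expert_ids: TensorMeta, input_: TensorMeta, expert_step: int) -> None:
    """Raise ValueError if expert_ids would be unsafe to hand to the CUDA kernel."""
    if expert_ids.dtype != "int32":
        raise ValueError("expert_ids must have dtype int32")
    # A CPU-resident expert_ids dereferenced from a CUDA kernel is an illegal access.
    if (expert_ids.device_type, expert_ids.device_id) != (input_.device_type, input_.device_id):
        raise ValueError("expert_ids must be on the same device as input")
    if expert_ids.ndim != 1:
        raise ValueError("expert_ids must be a 1D tensor")
    if expert_step < 1:
        raise ValueError("expert_step must be positive")
```

Each condition mirrors one `RuntimeCheck` in the suggested C++ above; catching them at the host boundary turns a crash into a clear error message.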
/tag-and-rerun-ci
Remove section banners and trailing comments that restate the next line of code. Keep load-bearing WHY: the HIP/XPU fall-through note in fused_moe.py (downstream zero-write makes it safe), test docstrings, and the sentinel-NaN rationale. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…gelu_and_mul (sgl-project#23707) Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Motivation
`act_and_mul_triton` (in `fused_moe_triton_kernels.py`) duplicates `silu_and_mul` / `gelu_and_mul`. The only difference is that it skips rows whose routed expert id is `-1` (the `filter_expert=True` MoE path used under EP). The JIT CUDA `silu_and_mul` / `gelu_and_mul` kernels already exist and are faster — the consolidation removes ~100 lines of Triton and a redundant kernel.

Modifications
JIT activation kernel (CUDA)
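As a reference for the kernel's filtering semantics (activation of the gate half times the up half, skipping rows whose routed expert is -1), here is an illustrative NumPy sketch; it is not the CUDA implementation:

```python
import numpy as np


def silu_and_mul_filtered(x, expert_ids=None, expert_step=1):
    """Reference semantics: out[i] = silu(x[i, :d]) * x[i, d:], except rows whose
    routed expert id is -1, which are skipped (the real kernel never writes them)."""
    n, two_d = x.shape
    d = two_d // 2
    out = np.zeros((n, d), dtype=x.dtype)
    for i in range(n):
        # expert_step maps activation rows to entries of expert_ids
        # (1 for per-token routing, >1 for sorted/TMA tile layouts)
        if expert_ids is not None and expert_ids[i // expert_step] == -1:
            continue  # filtered expert: skip this row
        gate, up = x[i, :d], x[i, d:]
        out[i] = gate / (1.0 + np.exp(-gate)) * up  # silu(gate) * up
    return out
```

With `expert_ids=None` this matches the unfiltered `silu_and_mul`; each `-1` entry leaves its `expert_step` rows untouched, which is why the compile-time `kFilterExpert` switch can be zero-cost when filtering is off.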
- `python/sglang/jit_kernel/csrc/elementwise/activation.cuh`: added `expert_ids` (`const int32_t*`) and `expert_step` (`uint32_t`) to `ActivationParams`; added a compile-time `kFilterExpert` template bool to `act_and_mul_kernel` (zero overhead when off — `if constexpr` skips the load); exposed a `run_activation_filtered` host method.
- `python/sglang/jit_kernel/activation.py`: `silu_and_mul` / `gelu_and_mul` / `gelu_tanh_and_mul` / `run_activation` now accept optional `expert_ids` and `expert_step` kwargs. Existing call sites are unchanged.

MoE call sites (`triton_utils/fused_moe.py`)

- Replaced `act_and_mul_triton(...)` calls with `silu_and_mul` / `gelu_and_mul` passing `expert_ids` and `expert_step`.
- CUDA: the filtered path goes through the JIT kernel with `expert_ids` (skips filtered rows).
- HIP/XPU: `sgl_kernel.silu_and_mul` / `gelu_and_mul` (unfiltered). The downstream fused MoE down kernel writes zeros for filtered experts and returns before reading the activation input (`fused_moe_triton_kernels.py:192–208`), so writing real output to those rows is harmless. This avoids exposing the JIT activation kernel to AMD users for the first time in this PR.

Removed

- `act_and_mul_kernel` (Triton) and the `act_and_mul_triton` wrapper, plus the now-unused `_apply_activation` and `tanh` helpers in `fused_moe_triton_kernels.py`.

Tests (`python/sglang/jit_kernel/tests/test_activation.py`)

- `test_activation_filter_expert` parametrized over op × dtype × shape × `expert_step ∈ {1, 16}` (per-token and sorted/TMA routing).
- `test_activation_filter_expert_all_skipped` and `test_activation_filter_expert_none_skipped` edge cases (the latter asserts bit-exact equality with the unfiltered path).

Benchmark (`python/sglang/jit_kernel/benchmark/bench_activation.py`)

- `benchmark_filter` comparing the filtered JIT path vs the unfiltered baseline across batch × dim × skip ratio.

Accuracy Tests
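The all-skipped and none-skipped edge cases can be illustrated with a small NumPy check; this is a hypothetical sketch of what the tests assert, not the actual pytest code:

```python
import numpy as np


def silu_and_mul(x):
    # unfiltered baseline: silu(gate) * up over the two halves of the last dim
    d = x.shape[1] // 2
    gate, up = x[:, :d], x[:, d:]
    return gate / (1.0 + np.exp(-gate)) * up


def silu_and_mul_filtered(x, expert_ids, expert_step=1):
    # keep row i iff its routed expert (expert_ids[i // expert_step]) is not -1
    keep = expert_ids[np.arange(x.shape[0]) // expert_step] != -1
    out = np.zeros((x.shape[0], x.shape[1] // 2), dtype=x.dtype)
    out[keep] = silu_and_mul(x[keep])
    return out


x = np.random.default_rng(0).standard_normal((16, 8))
# none skipped: the filtered path must match the unfiltered one bit-exactly
assert (silu_and_mul_filtered(x, np.zeros(16, dtype=np.int32)) == silu_and_mul(x)).all()
# all skipped: no row is written
assert (silu_and_mul_filtered(x, np.full(16, -1, dtype=np.int32)) == 0.0).all()
```

The none-skipped case is the important one: equality must be exact, not approximate, to show the filter machinery does not perturb the math.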
`pytest python/sglang/jit_kernel/tests/test_activation.py` — 348 passed in 23s (135 pre-existing + 213 new filter_expert).

The `none_skipped` test asserts bit-exact equality between the filtered kernel (with all `expert_ids = 0`) and the unfiltered kernel, so the filter machinery does not perturb the math.

Speed Tests and Profiling
The new filter benchmark (`bench_activation.py::benchmark_filter`, bf16):

- `skip_ratio=0`: filter overhead is bounded (≤ ~0.3μs absolute, negligible at large shapes — e.g. 213.89 → 213.63μs).
- `skip_ratio>0`: work scales linearly with skipped rows, exactly as expected.

Compared to the deleted Triton `act_and_mul_triton` kernel (measured against a local copy before removal), the JIT path is ~2–3× faster in the launch-bound regime (small/medium batches, the regime MoE expert tiles hit during decode) and within ±5% at HBM-bandwidth-bound shapes.

Checklist
Notes for reviewers
- AMD CI does not exercise the JIT kernel tests (the `pr-test-jit-kernel.yml` workflow runs only `--hw cuda`). To avoid making this PR a "first run on AMD" gamble, the HIP filter_expert path falls back to the unfiltered AOT `sgl_kernel` kernel and accepts the small wasted compute on filtered rows. A follow-up PR registering `test_activation.py` on AMD CI would let us route HIP filter_expert through the JIT kernel too.
- Removed `_apply_activation` and `tanh` from `fused_moe_triton_kernels.py` — these had no other callers (verified via grep across the repo).

🤖 Generated with Claude Code