diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
index 15ce917166c4..d63943f75106 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
@@ -9,11 +9,13 @@
     FlexCtx,
     FnSpecs,
     FusedActivation,
+    GatherIndx,
     PrecisionConfig,
+    RoutingData,
+    ScatterIndx,
     matmul_ogs,
 )
 from triton_kernels.numerics import InFlexData
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 from triton_kernels.swiglu import swiglu_fn
 
 from sglang.srt.utils import is_cuda
@@ -295,9 +297,8 @@ def triton_kernel_fused_experts_with_bias(
     w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
 
     act = FusedActivation(
-        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2),
         (gemm1_alpha, gemm1_clamp_limit),
-        2,
     )
 
     intermediate_cache = torch.empty(
diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_kernels.py b/python/sglang/srt/layers/moe/moe_runner/triton_kernels.py
index a90add0faaa8..258761e505de 100644
--- a/python/sglang/srt/layers/moe/moe_runner/triton_kernels.py
+++ b/python/sglang/srt/layers/moe/moe_runner/triton_kernels.py
@@ -19,8 +19,12 @@
 from sglang.srt.layers.moe.utils import MoeRunnerBackend
 
 if TYPE_CHECKING:
-    from triton_kernels.matmul_ogs import PrecisionConfig
-    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+    from triton_kernels.matmul_ogs import (
+        GatherIndx,
+        PrecisionConfig,
+        RoutingData,
+        ScatterIndx,
+    )
 
     from sglang.srt.layers.moe.token_dispatcher.standard import (
         StandardCombineInput,
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 2eba4f094a3a..02e88da0c759 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -32,7 +32,56 @@
 import torch.nn.functional as F
 
 try:
-    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+    from triton_kernels.matmul_ogs import GatherIndx, RoutingData, ScatterIndx
+    from triton_kernels.tensor import make_ragged_tensor_metadata
+    from triton_kernels.topk import topk as triton_kernels_topk
+
+    def routing(
+        logits,
+        n_expts_act,
+        sm_first=False,
+        expt_indx=None,
+        simulated_ep=1,
+        n_rows=None,
+    ):
+        """Local replacement for the ``routing`` helper of ``triton_kernels.routing``.
+
+        Returns ``(routing_data, gather_indx, scatter_indx)``. Only
+        ``simulated_ep == 1`` is supported; other values raise
+        ``NotImplementedError``.
+        """
+        if simulated_ep != 1:
+            raise NotImplementedError(
+                "simulated_ep routing is not supported with triton_kernels 3.6.0"
+            )
+
+        if sm_first:
+            logits = torch.softmax(logits, dim=-1)
+
+        sparse_logits = triton_kernels_topk(
+            logits,
+            n_expts_act,
+            apply_softmax=not sm_first,
+            y_indx=expt_indx,
+            n_rows=n_rows,
+        )
+        dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
+        combine_indx = sparse_logits.mask_metadata.col_sorted_indx
+        ragged_metadata = make_ragged_tensor_metadata(
+            sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]
+        )
+        gate_scal = sparse_logits.vals.flatten()[combine_indx]
+        routing_data = RoutingData(
+            gate_scal,
+            ragged_metadata.slice_sizes,
+            logits.shape[-1],
+            n_expts_act,
+            ragged_metadata,
+        )
+        gather_indx = GatherIndx(combine_indx, dispatch_indx)
+        scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
+        return routing_data, gather_indx, scatter_indx
+
 except ImportError:
     pass
 
diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index d5ad4d403493..c553723c8205 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -113,6 +113,7 @@ def _get_flashinfer_mxfp4_device_permute_indices(
 _is_hip = is_hip()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_shuffle_moe_mxfp4 = is_gfx95_supported()
+_sm120_mxfp4_min_warps_patched = False
 
 if _is_hip:
     # import aiter
@@ -128,6 +129,54 @@ def _get_flashinfer_mxfp4_device_permute_indices(
         dynamic_mxfp4_quant = e8m0_shuffle = err
 
 
+def _patch_sm120_mxfp4_min_warps():
+    """Install a one-time wrapper around ``opt_flags_nvidia.compute_num_warps``.
+
+    The wrapper returns at least 4 warps when ``is_persistent`` is false and
+    ``precision_config.weight_scale`` is set with a ``StridedLayout`` layout.
+    """
+    global _sm120_mxfp4_min_warps_patched
+    if _sm120_mxfp4_min_warps_patched:
+        return
+
+    import inspect
+
+    from triton_kernels.matmul_ogs_details.opt_flags_details import opt_flags_nvidia
+    from triton_kernels.tensor import get_layout
+    from triton_kernels.tensor_details.layout import StridedLayout
+
+    compute_num_warps = opt_flags_nvidia.compute_num_warps
+    params = inspect.signature(compute_num_warps).parameters
+
+    if "is_persistent" in params and not getattr(
+        compute_num_warps, "_sglang_sm120_mxfp4_patch", False
+    ):
+
+        def _compute_num_warps_sm120_mxfp4(
+            block_m, block_n, is_persistent, precision_config
+        ):
+            selected_num_warps = compute_num_warps(
+                block_m, block_n, is_persistent, precision_config
+            )
+            weight_scale = getattr(precision_config, "weight_scale", None)
+            weight_scale_layout = get_layout(weight_scale)
+            if (
+                not is_persistent
+                and weight_scale is not None
+                and (
+                    weight_scale_layout is StridedLayout
+                    or isinstance(weight_scale_layout, StridedLayout)
+                )
+            ):
+                return max(selected_num_warps, 4)
+            return selected_num_warps
+
+        _compute_num_warps_sm120_mxfp4._sglang_sm120_mxfp4_patch = True
+        opt_flags_nvidia.compute_num_warps = _compute_num_warps_sm120_mxfp4
+
+    _sm120_mxfp4_min_warps_patched = True
+
+
 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
     import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
@@ -137,8 +186,8 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
 
     if is_sm120_supported():
         # SM120 desktop Blackwell does not support the persistent/TMA MXFP4 path.
-        # This MXFP4 path uses StridedLayout and the non-persistent kernel with
-        # block_k=128 so the selected tile stays within the per-block shared-memory budget.
+        # This MXFP4 path uses StridedLayout and the non-persistent kernel.
+        _patch_sm120_mxfp4_min_warps()
         from triton_kernels.tensor_details.layout import StridedLayout
 
         value_layout = StridedLayout
@@ -147,7 +196,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
         scale_layout_opts = {}
         constraints = {
             "is_persistent": False,
-            "block_k": 128,
             "num_stages": 1,
         }
         opt_flags.update_opt_flags_constraints(constraints)
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 2d5d5898155d..2a4ec55ac019 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -3602,7 +3602,18 @@ async def wait_for_zero(self):
 
 @lru_cache(maxsize=1)
 def is_triton_kernels_available() -> bool:
-    return importlib.util.find_spec("triton_kernels") is not None
+    """Return whether triton_kernels with tensor_details.ragged_tensor exists."""
+    if importlib.util.find_spec("triton_kernels") is None:
+        return False
+    try:
+        # find_spec() imports the parent package of a dotted name and raises
+        # ModuleNotFoundError if that parent is missing, so guard the lookup.
+        ragged_metadata_spec = importlib.util.find_spec(
+            "triton_kernels.tensor_details.ragged_tensor"
+        )
+    except ImportError:
+        return False
+    return ragged_metadata_spec is not None
 
 
 @lru_cache(maxsize=1)
diff --git a/scripts/ci/cuda/ci_install_dependency.sh b/scripts/ci/cuda/ci_install_dependency.sh
index d53e52c0b676..66a37310de77 100755
--- a/scripts/ci/cuda/ci_install_dependency.sh
+++ b/scripts/ci/cuda/ci_install_dependency.sh
@@ -285,11 +285,15 @@ install_sglang_kernel() {
         $PIP_CMD install "torch==${TORCH_VER}" "torchaudio==${TORCHAUDIO_VER}" "torchvision==${TORCHVISION_VER}" --index-url "https://download.pytorch.org/whl/${CU_VERSION}" --force-reinstall --no-deps $PIP_INSTALL_SUFFIX
     fi
 
-    # install_sglang above pulls sglang-kernel from PyPI, whose default wheel
-    # tracks one CUDA version (currently cu130). Force-reinstall from the
-    # CU_VERSION-matched sglang wheel index so runners on a different CUDA
-    # (e.g. h20 / cu129) get a wheel linked against the right libnvrtc.
-    $PIP_CMD install "sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT}" --index-url "https://docs.sglang.ai/whl/${CU_VERSION}/" --force-reinstall --no-deps $PIP_INSTALL_SUFFIX
+    if [ "${CUSTOM_BUILD_SGL_KERNEL:-}" != "true" ]; then
+        # install_sglang above pulls sglang-kernel from PyPI, whose default wheel
+        # tracks one CUDA version (currently cu130). Force-reinstall from the
+        # CU_VERSION-matched sglang wheel index so runners on a different CUDA
+        # (e.g. h20 / cu129) get a wheel linked against the right libnvrtc.
+        $PIP_CMD install "sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT}" --index-url "https://docs.sglang.ai/whl/${CU_VERSION}/" --force-reinstall --no-deps $PIP_INSTALL_SUFFIX
+    else
+        echo "CUSTOM_BUILD_SGL_KERNEL=true: keeping freshly built sgl-kernel wheel."
+    fi
 
     mark_step_done "${FUNCNAME[0]}"
 }
diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index bbacf6dc4021..37ba3a71f786 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -74,7 +74,7 @@ FetchContent_Populate(repo-fmt)
 FetchContent_Declare(
     repo-triton
     GIT_REPOSITORY "https://github.com/triton-lang/triton"
-    GIT_TAG v3.5.1
+    GIT_TAG v3.6.0
     GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-triton)