@@ -9,11 +9,13 @@
     FlexCtx,
     FnSpecs,
     FusedActivation,
+    GatherIndx,
     PrecisionConfig,
+    RoutingData,
+    ScatterIndx,
     matmul_ogs,
 )
 from triton_kernels.numerics import InFlexData
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 from triton_kernels.swiglu import swiglu_fn

 from sglang.srt.utils import is_cuda
@@ -295,9 +297,8 @@ def triton_kernel_fused_experts_with_bias(
     w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))

     act = FusedActivation(
-        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2),
         (gemm1_alpha, gemm1_clamp_limit),
-        2,
     )

     intermediate_cache = torch.empty(
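For reference, triton_kernels 3.6.0 moves the swiglu reduction factor from a positional FusedActivation argument into FnSpecs as reduction_n. A minimal sketch of the new call shape, assuming swiglu_fn is importable and using made-up alpha/limit values:

    from triton_kernels.matmul_ogs import FnSpecs, FusedActivation
    from triton_kernels.swiglu import swiglu_fn

    gemm1_alpha, gemm1_clamp_limit = 1.702, 7.0  # hypothetical values, for illustration only
    act = FusedActivation(
        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2),
        (gemm1_alpha, gemm1_clamp_limit),
    )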
8 changes: 6 additions & 2 deletions python/sglang/srt/layers/moe/moe_runner/triton_kernels.py
@@ -19,8 +19,12 @@
 from sglang.srt.layers.moe.utils import MoeRunnerBackend

 if TYPE_CHECKING:
-    from triton_kernels.matmul_ogs import PrecisionConfig
-    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+    from triton_kernels.matmul_ogs import (
+        GatherIndx,
+        PrecisionConfig,
+        RoutingData,
+        ScatterIndx,
+    )

     from sglang.srt.layers.moe.token_dispatcher.standard import (
         StandardCombineInput,
45 changes: 44 additions & 1 deletion python/sglang/srt/layers/moe/topk.py
@@ -32,7 +32,50 @@
 import torch.nn.functional as F

 try:
-    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+    from triton_kernels.matmul_ogs import GatherIndx, RoutingData, ScatterIndx
+    from triton_kernels.tensor import make_ragged_tensor_metadata
+    from triton_kernels.topk import topk as triton_kernels_topk
+
+    def routing(
+        logits,
+        n_expts_act,
+        sm_first=False,
+        expt_indx=None,
+        simulated_ep=1,
+        n_rows=None,
+    ):
+        if simulated_ep != 1:
+            raise NotImplementedError(
+                "simulated_ep routing is not supported with triton_kernels 3.6.0"
+            )
+
+        if sm_first:
+            logits = torch.softmax(logits, dim=-1)
+
+        sparse_logits = triton_kernels_topk(
+            logits,
+            n_expts_act,
+            apply_softmax=not sm_first,
+            y_indx=expt_indx,
+            n_rows=n_rows,
+        )
+        dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
+        combine_indx = sparse_logits.mask_metadata.col_sorted_indx
+        ragged_metadata = make_ragged_tensor_metadata(
+            sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]
+        )
+        gate_scal = sparse_logits.vals.flatten()[combine_indx]
+        routing_data = RoutingData(
+            gate_scal,
+            ragged_metadata.slice_sizes,
+            logits.shape[-1],
+            n_expts_act,
+            ragged_metadata,
+        )
+        gather_indx = GatherIndx(combine_indx, dispatch_indx)
+        scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
+        return routing_data, gather_indx, scatter_indx
+
 except ImportError:
     pass

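As a rough usage sketch (not part of the diff), the shim above is intended as a drop-in for the old triton_kernels.routing.routing entry point: it takes router logits and the number of active experts and returns the same (RoutingData, GatherIndx, ScatterIndx) triple. The tensor sizes below are made up:

    import torch

    n_tokens, n_experts, n_expts_act = 8, 16, 2  # hypothetical sizes
    logits = torch.randn(n_tokens, n_experts, device="cuda", dtype=torch.bfloat16)
    routing_data, gather_indx, scatter_indx = routing(logits, n_expts_act)
    # routing_data carries the gate scales plus the ragged-tensor metadata that
    # matmul_ogs consumes; the gather/scatter indices map tokens to and from the
    # expert-sorted order.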
49 changes: 46 additions & 3 deletions python/sglang/srt/layers/quantization/mxfp4.py
@@ -113,6 +113,7 @@ def _get_flashinfer_mxfp4_device_permute_indices(
 _is_hip = is_hip()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_shuffle_moe_mxfp4 = is_gfx95_supported()
+_sm120_mxfp4_min_warps_patched = False

 if _is_hip:
     # import aiter
@@ -128,6 +129,49 @@
         dynamic_mxfp4_quant = e8m0_shuffle = err

+def _patch_sm120_mxfp4_min_warps():
+    global _sm120_mxfp4_min_warps_patched
+    if _sm120_mxfp4_min_warps_patched:
+        return
+
+    import inspect
+
+    from triton_kernels.matmul_ogs_details.opt_flags_details import opt_flags_nvidia
+    from triton_kernels.tensor import get_layout
+    from triton_kernels.tensor_details.layout import StridedLayout
+
+    compute_num_warps = opt_flags_nvidia.compute_num_warps
+    params = inspect.signature(compute_num_warps).parameters
+
+    if "is_persistent" in params and not getattr(
+        compute_num_warps, "_sglang_sm120_mxfp4_patch", False
+    ):
+
+        def _compute_num_warps_sm120_mxfp4(
+            block_m, block_n, is_persistent, precision_config
+        ):
+            selected_num_warps = compute_num_warps(
+                block_m, block_n, is_persistent, precision_config
+            )
+            weight_scale = getattr(precision_config, "weight_scale", None)
+            weight_scale_layout = get_layout(weight_scale)
+            if (
+                not is_persistent
+                and weight_scale is not None
+                and (
+                    weight_scale_layout is StridedLayout
+                    or isinstance(weight_scale_layout, StridedLayout)
+                )
+            ):
+                return max(selected_num_warps, 4)
+            return selected_num_warps
+
+        _compute_num_warps_sm120_mxfp4._sglang_sm120_mxfp4_patch = True
+        opt_flags_nvidia.compute_num_warps = _compute_num_warps_sm120_mxfp4
+
+    _sm120_mxfp4_min_warps_patched = True
+
+
 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
     import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
@@ -137,8 +181,8 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):

     if is_sm120_supported():
         # SM120 desktop Blackwell does not support the persistent/TMA MXFP4 path.
-        # This MXFP4 path uses StridedLayout and the non-persistent kernel with
-        # block_k=128 so the selected tile stays within the per-block shared-memory budget.
+        # This MXFP4 path uses StridedLayout and the non-persistent kernel.
+        _patch_sm120_mxfp4_min_warps()
         from triton_kernels.tensor_details.layout import StridedLayout

         value_layout = StridedLayout
@@ -147,7 +191,6 @@
         scale_layout_opts = {}
         constraints = {
             "is_persistent": False,
-            "block_k": 128,
             "num_stages": 1,
         }
         opt_flags.update_opt_flags_constraints(constraints)
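A small, hypothetical sanity check of the monkey patch above (not part of the diff), assuming triton_kernels 3.6.0 is installed: once _patch_sm120_mxfp4_min_warps() has run, opt_flags_nvidia.compute_num_warps is the tagged wrapper, and repeated calls return early via the module-level flag:

    from triton_kernels.matmul_ogs_details.opt_flags_details import opt_flags_nvidia

    _patch_sm120_mxfp4_min_warps()
    _patch_sm120_mxfp4_min_warps()  # no-op: guarded by _sm120_mxfp4_min_warps_patched
    assert getattr(opt_flags_nvidia.compute_num_warps, "_sglang_sm120_mxfp4_patch", False)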
6 changes: 5 additions & 1 deletion python/sglang/srt/utils/common.py
@@ -3602,7 +3602,11 @@ async def wait_for_zero(self):

 @lru_cache(maxsize=1)
 def is_triton_kernels_available() -> bool:
-    return importlib.util.find_spec("triton_kernels") is not None
+    triton_kernels_spec = importlib.util.find_spec("triton_kernels")
+    ragged_metadata_spec = importlib.util.find_spec(
+        "triton_kernels.tensor_details.ragged_tensor"
+    )
+    return triton_kernels_spec is not None and ragged_metadata_spec is not None


 @lru_cache(maxsize=1)
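The stricter probe presumably distinguishes triton_kernels builds that ship the ragged-tensor helpers (used by the routing shim above) from older installs. A standalone sketch of the same check, with a guard added because find_spec on a dotted name imports the parent package first:

    import importlib.util

    has_pkg = importlib.util.find_spec("triton_kernels") is not None
    has_ragged = has_pkg and (
        importlib.util.find_spec("triton_kernels.tensor_details.ragged_tensor") is not None
    )
    print(f"triton_kernels usable for the new MoE path: {has_ragged}")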
14 changes: 9 additions & 5 deletions scripts/ci/cuda/ci_install_dependency.sh
@@ -285,11 +285,15 @@ install_sglang_kernel() {
         $PIP_CMD install "torch==${TORCH_VER}" "torchaudio==${TORCHAUDIO_VER}" "torchvision==${TORCHVISION_VER}" --index-url "https://download.pytorch.org/whl/${CU_VERSION}" --force-reinstall --no-deps $PIP_INSTALL_SUFFIX
     fi

-    # install_sglang above pulls sglang-kernel from PyPI, whose default wheel
-    # tracks one CUDA version (currently cu130). Force-reinstall from the
-    # CU_VERSION-matched sglang wheel index so runners on a different CUDA
-    # (e.g. h20 / cu129) get a wheel linked against the right libnvrtc.
-    $PIP_CMD install "sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT}" --index-url "https://docs.sglang.ai/whl/${CU_VERSION}/" --force-reinstall --no-deps $PIP_INSTALL_SUFFIX
+    if [ "${CUSTOM_BUILD_SGL_KERNEL:-}" != "true" ]; then
+        # install_sglang above pulls sglang-kernel from PyPI, whose default wheel
+        # tracks one CUDA version (currently cu130). Force-reinstall from the
+        # CU_VERSION-matched sglang wheel index so runners on a different CUDA
+        # (e.g. h20 / cu129) get a wheel linked against the right libnvrtc.
+        $PIP_CMD install "sglang-kernel==${SGL_KERNEL_VERSION_FROM_SRT}" --index-url "https://docs.sglang.ai/whl/${CU_VERSION}/" --force-reinstall --no-deps $PIP_INSTALL_SUFFIX
+    else
+        echo "CUSTOM_BUILD_SGL_KERNEL=true: keeping freshly built sgl-kernel wheel."
+    fi

     mark_step_done "${FUNCNAME[0]}"
 }
2 changes: 1 addition & 1 deletion sgl-kernel/CMakeLists.txt
@@ -74,7 +74,7 @@ FetchContent_Populate(repo-fmt)
 FetchContent_Declare(
     repo-triton
     GIT_REPOSITORY "https://github.com/triton-lang/triton"
-    GIT_TAG v3.5.1
+    GIT_TAG v3.6.0
     GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-triton)