Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ARG HOPPER_SBO=0
ARG HOPPER_SBO_DEEPEP_COMMIT=9f2fc4b3182a51044ae7ecb6610f7c9c3258c4d6
ARG DEEPEP_COMMIT=9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee
ARG BUILD_AND_DOWNLOAD_PARALLEL=8
ARG SGL_KERNEL_VERSION=0.4.2.post1
ARG SGL_KERNEL_VERSION=0.4.2.post2
ARG SGL_VERSION
ARG SGL_DEEP_GEMM_VERSION=0.1.0
ARG USE_LATEST_SGLANG=0
Expand All @@ -19,7 +19,7 @@ ARG PIP_DEFAULT_INDEX
ARG UBUNTU_MIRROR
ARG GITHUB_ARTIFACTORY=github.com
ARG INSTALL_FLASHINFER_JIT_CACHE=0
ARG FLASHINFER_VERSION=0.6.8.post1
ARG FLASHINFER_VERSION=0.6.11.post1
ARG MOONCAKE_VERSION=0.3.10.post2
# If other build args are needed, add them to MOONCAKE_COMPILE_ARG
ARG MOONCAKE_COMPILE_ARG="-DUSE_HTTP=ON -DUSE_MNNVL=ON -DUSE_CUDA=ON -DWITH_EP=ON"
Expand Down
10 changes: 5 additions & 5 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ dependencies = [
"datasets",
"einops",
"fastapi",
"flashinfer_python==0.6.8.post1", # keep it aligned with jit-cache version in Dockerfile
"flashinfer_cubin==0.6.8.post1",
"flashinfer_python==0.6.11.post1", # keep it aligned with jit-cache version in Dockerfile
"flashinfer_cubin==0.6.11.post1",
"gguf",
"interegular",
"llguidance>=0.7.11,<0.8.0",
Expand All @@ -37,7 +37,7 @@ dependencies = [
"ninja",
"easydict", # Required by remote model code (e.g. DeepSeek-OCR) loaded via trust_remote_code; validated by transformers 5.4+ check_imports
"numpy",
"nvidia-cutlass-dsl==4.4.2",
"nvidia-cutlass-dsl==4.5.0",
"nvidia-ml-py",
"openai-harmony==0.0.4",
"openai==2.6.1",
Expand All @@ -53,14 +53,14 @@ dependencies = [
"pydantic",
"python-multipart",
"pyzmq>=25.1.2",
"quack-kernels>=0.3.0",
"quack-kernels>=0.4.1",
"requests",
"scipy",
"sentencepiece",
"setproctitle",
"flash-attn-4>=4.0.0b9",
"sgl-deep-gemm==0.1.0",
"sglang-kernel==0.4.2.post1",
"sglang-kernel==0.4.2.post2",
"soundfile==0.13.1",
"tiktoken",
"tilelang==0.1.8",
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,15 +1201,15 @@ def _set_envs_and_config(server_args: ServerArgs):
if server_args.attention_backend == "flashinfer":
assert_pkg_version(
"flashinfer_python",
"0.6.8.post1",
"0.6.11.post1",
"Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
)
if _is_cuda:
assert_pkg_version(
"sglang-kernel",
"0.4.2.post1",
"0.4.2.post2",
"Please reinstall the latest version with `pip install sglang-kernel --force-reinstall`",
)

Expand Down
24 changes: 11 additions & 13 deletions python/sglang/srt/layers/flashinfer_comm_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,11 @@ def initialize(
hidden_dim=hidden_dim,
dtype=dtype,
force_oneshot_support=bool(use_oneshot),
# Pin the symmetric-memory rendezvous to the actual
# subgroup. Without this, flashinfer >=0.6.10 falls back
# to WORLD and TP/EP/CP subgroup peers get addressed
# incorrectly (kernel hangs in cuda-graph warmup).
group=device_group,
)
if (
_TorchDistBackend is not None
Expand Down Expand Up @@ -515,8 +520,6 @@ def ensure_workspace_initialized(
if not is_flashinfer_available() or _flashinfer_comm is None:
return False

tp_coordinator = get_tp_group()

if use_attn_tp_group:
world_size = get_attn_tensor_model_parallel_world_size()
rank = get_attn_tensor_model_parallel_rank()
Expand All @@ -531,17 +534,12 @@ def ensure_workspace_initialized(
rank = get_moe_tensor_parallel_rank()
coordinator = get_moe_tp_group()

# When the sub-group IS the full TP group, pass None so the workspace
# uses the default process group directly (no TorchDistBackend needed).
# For true sub-groups, use NCCL device_group for GPU/device mapping and
# GLOO cpu_group for metadata broadcasts (avoids NCCL collectives that
# interfere with CUDA graph capture).
if coordinator.device_group is tp_coordinator.device_group:
device_group = None
cpu_group = None
else:
device_group = coordinator.device_group
cpu_group = coordinator.cpu_group
# Always pass the coordinator's groups: flashinfer >=0.6.10 reads the
# rendezvous group from `group=...` (falling back to WORLD when None),
# so leaving it None silently rendezvouses on WORLD and the kernel ends
# up addressing the wrong peers in TP/EP/CP subgroup setups.
device_group = coordinator.device_group
cpu_group = coordinator.cpu_group

if world_size <= 1:
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
FlexCtx,
FnSpecs,
FusedActivation,
GatherIndx,
PrecisionConfig,
RoutingData,
ScatterIndx,
matmul_ogs,
)
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from triton_kernels.swiglu import swiglu_fn

from sglang.srt.utils import is_cuda
Expand Down Expand Up @@ -297,9 +299,8 @@ def triton_kernel_fused_experts_with_bias(
w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))

act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2),
(gemm1_alpha, gemm1_clamp_limit),
2,
)

intermediate_cache = torch.empty(
Expand Down
8 changes: 6 additions & 2 deletions python/sglang/srt/layers/moe/moe_runner/triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
from sglang.srt.layers.moe.utils import MoeRunnerBackend

if TYPE_CHECKING:
from triton_kernels.matmul_ogs import PrecisionConfig
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from triton_kernels.matmul_ogs import (
GatherIndx,
PrecisionConfig,
RoutingData,
ScatterIndx,
)

from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
Expand Down
45 changes: 44 additions & 1 deletion python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,50 @@
import torch.nn.functional as F

try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
from triton_kernels.matmul_ogs import GatherIndx, RoutingData, ScatterIndx
from triton_kernels.tensor import make_ragged_tensor_metadata
from triton_kernels.topk import topk as triton_kernels_topk

def routing(
    logits,
    n_expts_act,
    sm_first=False,
    expt_indx=None,
    simulated_ep=1,
    n_rows=None,
):
    """Build routing metadata from router logits via the triton_kernels top-k.

    Args:
        logits: Router scores. The size of the last dimension is forwarded
            to ``RoutingData`` (presumably the expert count -- confirm).
        n_expts_act: Number of entries selected per row by the top-k.
        sm_first: When true, softmax is applied to ``logits`` here, before
            the top-k; otherwise the top-k kernel is asked to apply it.
        expt_indx: Forwarded to the top-k kernel as ``y_indx``.
        simulated_ep: Only ``1`` is accepted.
        n_rows: Forwarded unchanged to the top-k kernel.

    Returns:
        A ``(RoutingData, GatherIndx, ScatterIndx)`` tuple.

    Raises:
        NotImplementedError: If ``simulated_ep`` is not ``1``.
    """
    if simulated_ep != 1:
        raise NotImplementedError(
            "simulated_ep routing is not supported with triton_kernels 3.6.0"
        )

    # Softmax runs exactly once: here when sm_first, otherwise in the kernel.
    scores = torch.softmax(logits, dim=-1) if sm_first else logits

    topk_result = triton_kernels_topk(
        scores,
        n_expts_act,
        apply_softmax=not sm_first,
        y_indx=expt_indx,
        n_rows=n_rows,
    )

    row_sorted = topk_result.mask_metadata.row_sorted_indx
    col_sorted = topk_result.mask_metadata.col_sorted_indx
    ragged = make_ragged_tensor_metadata(
        topk_result.mask_metadata.col_sum, row_sorted.shape[0]
    )

    # Reorder the selected values by the column-sorted index.
    gate_weights = topk_result.vals.flatten()[col_sorted]

    return (
        RoutingData(
            gate_weights,
            ragged.slice_sizes,
            scores.shape[-1],
            n_expts_act,
            ragged,
        ),
        GatherIndx(col_sorted, row_sorted),
        ScatterIndx(row_sorted, col_sorted),
    )

except ImportError:
pass

Expand Down
14 changes: 7 additions & 7 deletions python/sglang/srt/layers/quantization/fp4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def _flashinfer_fp4_quantize_impl(
enable_pdl: Optional[bool] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return _flashinfer_fp4_quantize(
input,
global_scale,
sf_vec_size,
sf_use_ue8m0,
is_sf_swizzled_layout,
is_sf_8x4_layout,
enable_pdl,
input=input,
global_scale=global_scale,
sf_vec_size=sf_vec_size,
sf_use_ue8m0=sf_use_ue8m0,
is_sf_swizzled_layout=is_sf_swizzled_layout,
is_sf_8x4_layout=is_sf_8x4_layout,
enable_pdl=enable_pdl,
backend=_flashinfer_fp4_quantize_backend,
)

Expand Down
49 changes: 46 additions & 3 deletions python/sglang/srt/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def _get_flashinfer_mxfp4_device_permute_indices(
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_shuffle_moe_mxfp4 = is_gfx95_supported()
_sm120_mxfp4_min_warps_patched = False

if _is_hip:
# import aiter
Expand All @@ -156,6 +157,49 @@ def _get_flashinfer_mxfp4_device_permute_indices(
dynamic_mxfp4_quant = e8m0_shuffle = err


def _patch_sm120_mxfp4_min_warps():
    """Wrap triton_kernels' ``compute_num_warps`` to enforce a 4-warp floor.

    The wrapper delegates to the original ``compute_num_warps`` and raises
    the result to at least 4 when the kernel is non-persistent and the
    precision config carries a weight scale whose layout is
    ``StridedLayout`` (either the class itself or an instance of it).

    The wrapper is installed only if the existing function takes an
    ``is_persistent`` parameter and is not already marked as patched
    (presumably a compatibility guard across triton_kernels versions --
    confirm). A module-level flag makes repeat calls return immediately.
    """
    global _sm120_mxfp4_min_warps_patched
    if _sm120_mxfp4_min_warps_patched:
        return

    import inspect

    from triton_kernels.matmul_ogs_details.opt_flags_details import opt_flags_nvidia
    from triton_kernels.tensor import get_layout
    from triton_kernels.tensor_details.layout import StridedLayout

    original_compute_num_warps = opt_flags_nvidia.compute_num_warps
    signature_params = inspect.signature(original_compute_num_warps).parameters

    takes_is_persistent = "is_persistent" in signature_params
    if takes_is_persistent and not getattr(
        original_compute_num_warps, "_sglang_sm120_mxfp4_patch", False
    ):

        def _compute_num_warps_sm120_mxfp4(
            block_m, block_n, is_persistent, precision_config
        ):
            base_warps = original_compute_num_warps(
                block_m, block_n, is_persistent, precision_config
            )
            weight_scale = getattr(precision_config, "weight_scale", None)
            scale_layout = get_layout(weight_scale)

            # Persistent kernels and configs without a weight scale keep
            # whatever the original heuristic selected.
            if is_persistent or weight_scale is None:
                return base_warps

            is_strided = scale_layout is StridedLayout or isinstance(
                scale_layout, StridedLayout
            )
            if is_strided:
                return max(base_warps, 4)
            return base_warps

        # Marker attribute lets a later call detect an installed wrapper.
        _compute_num_warps_sm120_mxfp4._sglang_sm120_mxfp4_patch = True
        opt_flags_nvidia.compute_num_warps = _compute_num_warps_sm120_mxfp4

    _sm120_mxfp4_min_warps_patched = True


def _swizzle_mxfp4(quant_tensor, scale, num_warps):
"""weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
Expand All @@ -165,8 +209,8 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):

if is_sm120_supported():
# SM120 desktop Blackwell does not support the persistent/TMA MXFP4 path.
# This MXFP4 path uses StridedLayout and the non-persistent kernel with
# block_k=128 so the selected tile stays within the per-block shared-memory budget.
# This MXFP4 path uses StridedLayout and the non-persistent kernel.
_patch_sm120_mxfp4_min_warps()
from triton_kernels.tensor_details.layout import StridedLayout

value_layout = StridedLayout
Expand All @@ -175,7 +219,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
scale_layout_opts = {}
constraints = {
"is_persistent": False,
"block_k": 128,
"num_stages": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
Expand Down
12 changes: 10 additions & 2 deletions python/sglang/srt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ def check_pkg_version_at_least(pkg: str, min_version: str) -> bool:

Args:
pkg: Package name (distribution name, e.g., "flashinfer-python")
min_version: Minimum version required (e.g., "0.6.8.post1")
min_version: Minimum version required (e.g., "0.6.11.post1")

Returns:
True if package is installed and version >= min_version, False otherwise
Expand Down Expand Up @@ -3661,7 +3661,15 @@ async def wait_for_zero(self):

@lru_cache(maxsize=1)
def is_triton_kernels_available() -> bool:
    """Return whether a usable ``triton_kernels`` package is installed.

    "Usable" means both that the top-level ``triton_kernels`` package can be
    located and that it ships the
    ``triton_kernels.tensor_details.ragged_tensor`` submodule (presumably to
    reject older releases that lack the ragged-tensor metadata API --
    confirm against the supported triton_kernels versions).

    The result is memoized, so the probe runs at most once per process.

    Returns:
        True if both the package and the submodule are found, else False.
    """
    if importlib.util.find_spec("triton_kernels") is None:
        return False
    try:
        ragged_metadata_spec = importlib.util.find_spec(
            "triton_kernels.tensor_details.ragged_tensor"
        )
    except ModuleNotFoundError:
        # find_spec imports the parent of a dotted name; a missing
        # ``triton_kernels.tensor_details`` package surfaces here.
        return False
    return ragged_metadata_spec is not None


@lru_cache(maxsize=1)
Expand Down
1 change: 0 additions & 1 deletion test/registered/dsv4/test_deepseek_v4_flash_fp4_h200.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def test_gsm8k(self):
self.assertGreater(metrics["score"], 0.93)


@unittest.skip("broken on main, see #24816")
@unittest.skipUnless(
_flashinfer_has_sm90_cutlass_mxfp4(),
"FlashInfer build lacks SM90 mixed-input MXFP4 helpers (PR #3084, >= 0.6.11)",
Expand Down
15 changes: 9 additions & 6 deletions test/registered/moe/test_cutedsl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,13 +899,14 @@ def test_v1_masked_kernel_bf16_input(self):
masked_m.to(hidden_states.device),
)

a_global_scale = input_global_scale[:1]
a_fp4, a_scale_interleaved = fp4_quantize(
hidden_states, input_global_scale
hidden_states, a_global_scale
)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
input_global_scale,
a_global_scale,
dtype=hidden_states.dtype,
device=hidden_states.device,
block_size=16,
Expand Down Expand Up @@ -1077,11 +1078,12 @@ def test_v1_masked_kernel_rejects_v2_w13_layout(self):
masked_m.to(device),
)

a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
a_global_scale = input_global_scale[:1]
a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, a_global_scale)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
input_global_scale,
a_global_scale,
dtype=hidden_states.dtype,
device=device,
block_size=16,
Expand Down Expand Up @@ -1251,11 +1253,12 @@ def test_v1_masked_kernel_fp4_input(self):
)

# PyTorch reference (same as the bf16 input test)
a_fp4, a_scale = fp4_quantize(hidden_states, input_gs)
a_gs = input_gs[:1]
a_fp4, a_scale = fp4_quantize(hidden_states, a_gs)
a_deq = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale,
input_gs,
a_gs,
dtype=torch.bfloat16,
device=device,
block_size=16,
Expand Down
Loading
Loading