diff --git a/docker/Dockerfile b/docker/Dockerfile index 1e9665a04219..ad03168e6cef 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,7 +10,7 @@ ARG HOPPER_SBO=0 ARG HOPPER_SBO_DEEPEP_COMMIT=9f2fc4b3182a51044ae7ecb6610f7c9c3258c4d6 ARG DEEPEP_COMMIT=9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee ARG BUILD_AND_DOWNLOAD_PARALLEL=8 -ARG SGL_KERNEL_VERSION=0.4.2.post1 +ARG SGL_KERNEL_VERSION=0.4.2.post2 ARG SGL_VERSION ARG SGL_DEEP_GEMM_VERSION=0.1.0 ARG USE_LATEST_SGLANG=0 @@ -19,7 +19,7 @@ ARG PIP_DEFAULT_INDEX ARG UBUNTU_MIRROR ARG GITHUB_ARTIFACTORY=github.com ARG INSTALL_FLASHINFER_JIT_CACHE=0 -ARG FLASHINFER_VERSION=0.6.8.post1 +ARG FLASHINFER_VERSION=0.6.11.post1 ARG MOONCAKE_VERSION=0.3.10.post2 #if need other arg please add in MOONCAKE_COMPILE_ARG ARG MOONCAKE_COMPILE_ARG="-DUSE_HTTP=ON -DUSE_MNNVL=ON -DUSE_CUDA=ON -DWITH_EP=ON" diff --git a/python/pyproject.toml b/python/pyproject.toml index 6e496bad2a38..7ec785be51d8 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,8 +27,8 @@ dependencies = [ "datasets", "einops", "fastapi", - "flashinfer_python==0.6.8.post1", # keep it aligned with jit-cache version in Dockerfile - "flashinfer_cubin==0.6.8.post1", + "flashinfer_python==0.6.11.post1", # keep it aligned with jit-cache version in Dockerfile + "flashinfer_cubin==0.6.11.post1", "gguf", "interegular", "llguidance>=0.7.11,<0.8.0", @@ -37,7 +37,7 @@ dependencies = [ "ninja", "easydict", # Required by remote model code (e.g. DeepSeek-OCR) loaded via trust_remote_code; validated by transformers 5.4+ check_imports "numpy", - "nvidia-cutlass-dsl==4.4.2", + "nvidia-cutlass-dsl==4.5.0", "nvidia-ml-py", "openai-harmony==0.0.4", "openai==2.6.1", @@ -53,14 +53,14 @@ dependencies = [ "pydantic", "python-multipart", "pyzmq>=25.1.2", - "quack-kernels>=0.3.0", + "quack-kernels>=0.4.1", "requests", "scipy", "sentencepiece", "setproctitle", "flash-attn-4>=4.0.0b9", "sgl-deep-gemm==0.1.0", - "sglang-kernel==0.4.2.post1", + "sglang-kernel==0.4.2.post2", "soundfile==0.13.1", "tiktoken", "tilelang==0.1.8", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index f1788a03f00e..f96445c31055 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -1201,7 +1201,7 @@ def _set_envs_and_config(server_args: ServerArgs): if server_args.attention_backend == "flashinfer": assert_pkg_version( "flashinfer_python", - "0.6.8.post1", + "0.6.11.post1", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", @@ -1209,7 +1209,7 @@ def _set_envs_and_config(server_args: ServerArgs): if _is_cuda: assert_pkg_version( "sglang-kernel", - "0.4.2.post1", + "0.4.2.post2", "Please reinstall the latest version with `pip install sglang-kernel --force-reinstall`", ) diff --git a/python/sglang/srt/layers/flashinfer_comm_fusion.py b/python/sglang/srt/layers/flashinfer_comm_fusion.py index bca28f3e211c..c66c2cda4d8a 100644 --- a/python/sglang/srt/layers/flashinfer_comm_fusion.py +++ b/python/sglang/srt/layers/flashinfer_comm_fusion.py @@ -383,6 +383,11 @@ def initialize( hidden_dim=hidden_dim, dtype=dtype, force_oneshot_support=bool(use_oneshot), + # Pin the symmetric-memory rendezvous to the actual + # subgroup. Without this, flashinfer >=0.6.10 falls back + # to WORLD and TP/EP/CP subgroup peers get addressed + # incorrectly (kernel hangs in cuda-graph warmup). + group=device_group, ) if ( _TorchDistBackend is not None @@ -515,8 +520,6 @@ def ensure_workspace_initialized( if not is_flashinfer_available() or _flashinfer_comm is None: return False - tp_coordinator = get_tp_group() - if use_attn_tp_group: world_size = get_attn_tensor_model_parallel_world_size() rank = get_attn_tensor_model_parallel_rank() @@ -531,17 +534,12 @@ def ensure_workspace_initialized( rank = get_moe_tensor_parallel_rank() coordinator = get_moe_tp_group() - # When the sub-group IS the full TP group, pass None so the workspace - # uses the default process group directly (no TorchDistBackend needed). - # For true sub-groups, use NCCL device_group for GPU/device mapping and - # GLOO cpu_group for metadata broadcasts (avoids NCCL collectives that - # interfere with CUDA graph capture). - if coordinator.device_group is tp_coordinator.device_group: - device_group = None - cpu_group = None - else: - device_group = coordinator.device_group - cpu_group = coordinator.cpu_group + # Always pass the coordinator's groups: flashinfer >=0.6.10 reads the + # rendezvous group from `group=...` (falling back to WORLD when None), + # so leaving it None silently rendezvouses on WORLD and the kernel ends + # up addressing the wrong peers in TP/EP/CP subgroup setups. + device_group = coordinator.device_group + cpu_group = coordinator.cpu_group if world_size <= 1: return False diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py index 166aca26c247..d5e2a6537b7d 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py @@ -11,11 +11,13 @@ FlexCtx, FnSpecs, FusedActivation, + GatherIndx, PrecisionConfig, + RoutingData, + ScatterIndx, matmul_ogs, ) from triton_kernels.numerics import InFlexData -from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx from triton_kernels.swiglu import swiglu_fn from sglang.srt.utils import is_cuda @@ -297,9 +299,8 @@ def triton_kernel_fused_experts_with_bias( w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex)) act = FusedActivation( - FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), + FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2), (gemm1_alpha, gemm1_clamp_limit), - 2, ) intermediate_cache = torch.empty( diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_kernels.py b/python/sglang/srt/layers/moe/moe_runner/triton_kernels.py index a90add0faaa8..258761e505de 100644 --- a/python/sglang/srt/layers/moe/moe_runner/triton_kernels.py +++ b/python/sglang/srt/layers/moe/moe_runner/triton_kernels.py @@ -19,8 +19,12 @@ from sglang.srt.layers.moe.utils import MoeRunnerBackend if TYPE_CHECKING: - from triton_kernels.matmul_ogs import PrecisionConfig - from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx + from triton_kernels.matmul_ogs import ( + GatherIndx, + PrecisionConfig, + RoutingData, + ScatterIndx, + ) from sglang.srt.layers.moe.token_dispatcher.standard import ( StandardCombineInput, diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 64b671194aa7..96031a9dc1c3 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -32,7 +32,50 @@ import torch.nn.functional as F try: - from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing + from triton_kernels.matmul_ogs import GatherIndx, RoutingData, ScatterIndx + from triton_kernels.tensor import make_ragged_tensor_metadata + from triton_kernels.topk import topk as triton_kernels_topk + + def routing( + logits, + n_expts_act, + sm_first=False, + expt_indx=None, + simulated_ep=1, + n_rows=None, + ): + if simulated_ep != 1: + raise NotImplementedError( + "simulated_ep routing is not supported with triton_kernels 3.6.0" + ) + + if sm_first: + logits = torch.softmax(logits, dim=-1) + + sparse_logits = triton_kernels_topk( + logits, + n_expts_act, + apply_softmax=not sm_first, + y_indx=expt_indx, + n_rows=n_rows, + ) + dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx + combine_indx = sparse_logits.mask_metadata.col_sorted_indx + ragged_metadata = make_ragged_tensor_metadata( + sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0] + ) + gate_scal = sparse_logits.vals.flatten()[combine_indx] + routing_data = RoutingData( + gate_scal, + ragged_metadata.slice_sizes, + logits.shape[-1], + n_expts_act, + ragged_metadata, + ) + gather_indx = GatherIndx(combine_indx, dispatch_indx) + scatter_indx = ScatterIndx(dispatch_indx, combine_indx) + return routing_data, gather_indx, scatter_indx + except ImportError: pass diff --git a/python/sglang/srt/layers/quantization/fp4_utils.py b/python/sglang/srt/layers/quantization/fp4_utils.py index 96409750cbf7..a7a64f25e99e 100644 --- a/python/sglang/srt/layers/quantization/fp4_utils.py +++ b/python/sglang/srt/layers/quantization/fp4_utils.py @@ -34,13 +34,13 @@ def _flashinfer_fp4_quantize_impl( enable_pdl: Optional[bool] = None, ) -> tuple[torch.Tensor, torch.Tensor]: return _flashinfer_fp4_quantize( - input, - global_scale, - sf_vec_size, - sf_use_ue8m0, - is_sf_swizzled_layout, - is_sf_8x4_layout, - enable_pdl, + input=input, + global_scale=global_scale, + sf_vec_size=sf_vec_size, + sf_use_ue8m0=sf_use_ue8m0, + is_sf_swizzled_layout=is_sf_swizzled_layout, + is_sf_8x4_layout=is_sf_8x4_layout, + enable_pdl=enable_pdl, backend=_flashinfer_fp4_quantize_backend, ) diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 70425a73cfeb..a6b41db17be9 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -141,6 +141,7 @@ def _get_flashinfer_mxfp4_device_permute_indices( _is_hip = is_hip() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _is_shuffle_moe_mxfp4 = is_gfx95_supported() +_sm120_mxfp4_min_warps_patched = False if _is_hip: # import aiter @@ -156,6 +157,49 @@ def _get_flashinfer_mxfp4_device_permute_indices( dynamic_mxfp4_quant = e8m0_shuffle = err +def _patch_sm120_mxfp4_min_warps(): + global _sm120_mxfp4_min_warps_patched + if _sm120_mxfp4_min_warps_patched: + return + + import inspect + + from triton_kernels.matmul_ogs_details.opt_flags_details import opt_flags_nvidia + from triton_kernels.tensor import get_layout + from triton_kernels.tensor_details.layout import StridedLayout + + compute_num_warps = opt_flags_nvidia.compute_num_warps + params = inspect.signature(compute_num_warps).parameters + + if "is_persistent" in params and not getattr( + compute_num_warps, "_sglang_sm120_mxfp4_patch", False + ): + + def _compute_num_warps_sm120_mxfp4( + block_m, block_n, is_persistent, precision_config + ): + selected_num_warps = compute_num_warps( + block_m, block_n, is_persistent, precision_config + ) + weight_scale = getattr(precision_config, "weight_scale", None) + weight_scale_layout = get_layout(weight_scale) + if ( + not is_persistent + and weight_scale is not None + and ( + weight_scale_layout is StridedLayout + or isinstance(weight_scale_layout, StridedLayout) + ) + ): + return max(selected_num_warps, 4) + return selected_num_warps + + _compute_num_warps_sm120_mxfp4._sglang_sm120_mxfp4_patch = True + opt_flags_nvidia.compute_num_warps = _compute_num_warps_sm120_mxfp4 + + _sm120_mxfp4_min_warps_patched = True + + def _swizzle_mxfp4(quant_tensor, scale, num_warps): """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel""" import triton_kernels.matmul_ogs_details.opt_flags as opt_flags @@ -165,8 +209,8 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): if is_sm120_supported(): # SM120 desktop Blackwell does not support the persistent/TMA MXFP4 path. - # This MXFP4 path uses StridedLayout and the non-persistent kernel with - # block_k=128 so the selected tile stays within the per-block shared-memory budget. + # This MXFP4 path uses StridedLayout and the non-persistent kernel. + _patch_sm120_mxfp4_min_warps() from triton_kernels.tensor_details.layout import StridedLayout value_layout = StridedLayout @@ -175,7 +219,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): scale_layout_opts = {} constraints = { "is_persistent": False, - "block_k": 128, "num_stages": 1, } opt_flags.update_opt_flags_constraints(constraints) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 0ac387e3210c..48e108665b68 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1105,7 +1105,7 @@ def check_pkg_version_at_least(pkg: str, min_version: str) -> bool: Args: pkg: Package name (distribution name, e.g., "flashinfer-python") - min_version: Minimum version required (e.g., "0.6.8.post1") + min_version: Minimum version required (e.g., "0.6.11.post1") Returns: True if package is installed and version >= min_version, False otherwise @@ -3661,7 +3661,15 @@ async def wait_for_zero(self): @lru_cache(maxsize=1) def is_triton_kernels_available() -> bool: - return importlib.util.find_spec("triton_kernels") is not None + if importlib.util.find_spec("triton_kernels") is None: + return False + try: + ragged_metadata_spec = importlib.util.find_spec( + "triton_kernels.tensor_details.ragged_tensor" + ) + except ModuleNotFoundError: + return False + return ragged_metadata_spec is not None @lru_cache(maxsize=1) diff --git a/test/registered/dsv4/test_deepseek_v4_flash_fp4_h200.py b/test/registered/dsv4/test_deepseek_v4_flash_fp4_h200.py index 6e0d1c4eab73..bc1469fe1b3f 100644 --- a/test/registered/dsv4/test_deepseek_v4_flash_fp4_h200.py +++ b/test/registered/dsv4/test_deepseek_v4_flash_fp4_h200.py @@ -91,7 +91,6 @@ def test_gsm8k(self): self.assertGreater(metrics["score"], 0.93) -@unittest.skip("broken on main, see #24816") @unittest.skipUnless( _flashinfer_has_sm90_cutlass_mxfp4(), "FlashInfer build lacks SM90 mixed-input MXFP4 helpers (PR #3084, >= 0.6.11)", diff --git a/test/registered/moe/test_cutedsl_moe.py b/test/registered/moe/test_cutedsl_moe.py index ed54c206dd66..08ddd7db7b93 100644 --- a/test/registered/moe/test_cutedsl_moe.py +++ b/test/registered/moe/test_cutedsl_moe.py @@ -899,13 +899,14 @@ def test_v1_masked_kernel_bf16_input(self): masked_m.to(hidden_states.device), ) + a_global_scale = input_global_scale[:1] a_fp4, a_scale_interleaved = fp4_quantize( - hidden_states, input_global_scale + hidden_states, a_global_scale ) a_in_dtype = dequantize_nvfp4_to_dtype( a_fp4, a_scale_interleaved, - input_global_scale, + a_global_scale, dtype=hidden_states.dtype, device=hidden_states.device, block_size=16, @@ -1077,11 +1078,12 @@ def test_v1_masked_kernel_rejects_v2_w13_layout(self): masked_m.to(device), ) - a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale) + a_global_scale = input_global_scale[:1] + a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, a_global_scale) a_in_dtype = dequantize_nvfp4_to_dtype( a_fp4, a_scale_interleaved, - input_global_scale, + a_global_scale, dtype=hidden_states.dtype, device=device, block_size=16, @@ -1251,11 +1253,12 @@ def test_v1_masked_kernel_fp4_input(self): ) # PyTorch reference (same as the bf16 input test) - a_fp4, a_scale = fp4_quantize(hidden_states, input_gs) + a_gs = input_gs[:1] + a_fp4, a_scale = fp4_quantize(hidden_states, a_gs) a_deq = dequantize_nvfp4_to_dtype( a_fp4, a_scale, - input_gs, + a_gs, dtype=torch.bfloat16, device=device, block_size=16, diff --git a/test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py b/test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py index add4cdd12eb3..85ca0fd27f94 100644 --- a/test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py +++ b/test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py @@ -20,12 +20,7 @@ from sglang.test.ci.ci_register import register_cuda_ci -register_cuda_ci( - est_time=120, - stage="stage-b", - runner_config="1-gpu-large", - disabled="broken on main, see #24816", -) +register_cuda_ci(est_time=120, stage="stage-b", runner_config="1-gpu-large") flashinfer_fused_moe = pytest.importorskip("flashinfer.fused_moe")