From a526962753425a474d0727bdb9dc5cd360dd8eea Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 12 Dec 2025 13:40:14 -0800 Subject: [PATCH 1/2] Fix fp4 allgather for spec decode --- python/sglang/srt/layers/moe/utils.py | 6 ++++-- python/sglang/srt/layers/quantization/modelopt_quant.py | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index a220318a9258..b48e267aa294 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -3,7 +3,6 @@ import logging from contextlib import contextmanager from enum import Enum, IntEnum -from functools import lru_cache from typing import TYPE_CHECKING, Optional from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size @@ -250,7 +249,6 @@ def get_tbo_token_distribution_threshold() -> float: return TBO_TOKEN_DISTRIBUTION_THRESHOLD -@lru_cache(maxsize=1) def should_use_flashinfer_cutlass_moe_fp4_allgather(): """ Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving. @@ -286,12 +284,16 @@ def speculative_moe_a2a_backend_context(): This ensures that draft models in speculative decoding use the configured speculative A2A backend. """ global MOE_A2A_BACKEND + global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER original_backend = MOE_A2A_BACKEND try: MOE_A2A_BACKEND = get_speculative_moe_a2a_backend() + # Disable FP4 allgather for spec decode since MTP layers are unquantized + DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = True yield finally: MOE_A2A_BACKEND = original_backend + DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = False # The type of method in top-K routing, for use in torch custom op diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 84406d947796..4abb1fda0047 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -22,6 +22,7 @@ ) from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo +from sglang.srt.layers.moe.utils import should_use_flashinfer_cutlass_moe_fp4_allgather from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -1606,7 +1607,10 @@ def _slice_scale(w): layer.dispatcher.set_quant_config( { "input_global_scale": ( - layer.w13_input_scale_quant if MOE_NVFP4_DISPATCH else None + layer.w13_input_scale_quant + if MOE_NVFP4_DISPATCH + or should_use_flashinfer_cutlass_moe_fp4_allgather() + else None ) } ) From d0341bbefe0880e1983d4f102000b2876a481847 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 17 Dec 2025 00:09:03 +0000 Subject: [PATCH 2/2] Preserve original disable value --- python/sglang/srt/layers/moe/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index b48e267aa294..414140fda277 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -286,6 +286,9 @@ def speculative_moe_a2a_backend_context(): global MOE_A2A_BACKEND global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER original_backend = MOE_A2A_BACKEND + original_disable_flashinfer_cutlass_moe_fp4_allgather = ( + DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER + ) try: MOE_A2A_BACKEND = get_speculative_moe_a2a_backend() # Disable FP4 allgather for spec decode since MTP layers are unquantized @@ -293,7 +296,9 @@ def speculative_moe_a2a_backend_context(): yield finally: MOE_A2A_BACKEND = original_backend - DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = False + DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = ( + original_disable_flashinfer_cutlass_moe_fp4_allgather + ) # The type of method in top-K routing, for use in torch custom op