Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/sglang/srt/layers/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
from contextlib import contextmanager
from enum import Enum, IntEnum
from functools import lru_cache
from typing import TYPE_CHECKING, Optional

from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
Expand Down Expand Up @@ -250,7 +249,6 @@ def get_tbo_token_distribution_threshold() -> float:
return TBO_TOKEN_DISTRIBUTION_THRESHOLD


@lru_cache(maxsize=1)
def should_use_flashinfer_cutlass_moe_fp4_allgather():
"""
Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving.
Expand Down Expand Up @@ -286,12 +284,21 @@ def speculative_moe_a2a_backend_context():
This ensures that draft models in speculative decoding use the configured speculative A2A backend.
"""
global MOE_A2A_BACKEND
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
original_backend = MOE_A2A_BACKEND
original_disable_flashinfer_cutlass_moe_fp4_allgather = (
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
)
try:
MOE_A2A_BACKEND = get_speculative_moe_a2a_backend()
# Disable FP4 allgather for spec decode since MTP layers are unquantized
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = True
yield
finally:
MOE_A2A_BACKEND = original_backend
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
original_disable_flashinfer_cutlass_moe_fp4_allgather
)


# The type of method in top-K routing, for use in torch custom op
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import should_use_flashinfer_cutlass_moe_fp4_allgather
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
Expand Down Expand Up @@ -1606,7 +1607,10 @@ def _slice_scale(w):
layer.dispatcher.set_quant_config(
{
"input_global_scale": (
layer.w13_input_scale_quant if MOE_NVFP4_DISPATCH else None
layer.w13_input_scale_quant
if MOE_NVFP4_DISPATCH
or should_use_flashinfer_cutlass_moe_fp4_allgather()
else None
)
}
)
Expand Down
Loading