Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_cuda, next_power_of_2
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
Expand Down Expand Up @@ -74,6 +74,10 @@
# Initialize logger for the module
logger = logging.getLogger(__name__)

# Boolean switch read once at import time from the environment variable
# SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE ("true" is passed as the second
# argument -- presumably the default when the variable is unset; confirm
# against get_bool_env_var).
# It is consulted in the flashinfer cutedsl MoE branch of
# process_weights_after_loading: when truthy, w13_input_scale becomes the
# single global max repeated along the first dimension; otherwise the
# max over dim=1 is used.
CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
"SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
)

# Supported activation schemes for the current configuration
ACTIVATION_SCHEMES = ["static"]

Expand Down Expand Up @@ -1188,7 +1192,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
elif self.enable_flashinfer_cutedsl_moe:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
# Using one input scale shared by all experts is mathematically different
# from the default per-expert input scale, so we let users toggle the flag
# in order to test both behaviors thoroughly.
if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
w13_input_scale = (
layer.w13_input_scale.max()
.to(torch.float32)
.repeat(layer.w13_input_scale.shape[0])
)
else:
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
torch.float32
)

w2_input_scale = layer.w2_input_scale

def _slice_scale(w):
Expand Down
Loading