Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ def _verify_quantization(self) -> None:
compatible_quantization_methods = {
"modelopt_fp8": ["modelopt"],
"modelopt_fp4": ["modelopt"],
"modelopt_mxfp8": ["modelopt"],
"petit_nvfp4": ["modelopt"],
"w8a8_int8": ["compressed-tensors", "compressed_tensors"],
"w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/moe/cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
from sglang.srt.utils import is_cuda, is_sm90_supported, is_sm100_supported

# Workspace size required by CUTLASS grouped GEMM kernels (bytes).
# Used by both FP8 and MXFP8 MoE paths.
CUTLASS_MOE_WORKSPACE_BYTES = 90000

_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import (
Expand Down
24 changes: 22 additions & 2 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
CompressedTensorsMxInt4MoEMethod,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptMxfp8MoEMethod,
ModelOptNvFp4FusedMoEMethod,
)
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.server_args import get_global_server_args
Expand Down Expand Up @@ -768,8 +771,25 @@ def _weight_loader_impl(
return

if "ModelOpt" in self.quant_method.__class__.__name__:
# Determine per-tensor weight scale patterns based on variant
is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
is_mxfp8_variant = isinstance(self.quant_method, ModelOptMxfp8MoEMethod)

if is_mxfp8_variant:
# MXFP8: weight_scale is block scale (uint8 UE8M0), not per-tensor
if "weight_scale" in weight_name or "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
else:
logger.warning(
"MXFP8 MoE: ignoring unrecognized weight %s",
weight_name,
)
return

# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
per_tensor_conditions = (
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def override_quantization_method(self, *args, **kwargs):
from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptFp4Config,
ModelOptFp8Config,
ModelOptMxfp8Config,
)
from sglang.srt.layers.quantization.modelslim.modelslim import ModelSlimConfig
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
Expand All @@ -57,6 +58,7 @@ def override_quantization_method(self, *args, **kwargs):
"modelopt": ModelOptFp8Config, # Auto-detect, defaults to FP8
"modelopt_fp8": ModelOptFp8Config,
"modelopt_fp4": ModelOptFp4Config,
"modelopt_mxfp8": ModelOptMxfp8Config,
"w8a8_int8": W8A8Int8Config,
"w8a8_fp8": W8A8Fp8Config,
"awq": AWQConfig,
Expand Down
12 changes: 9 additions & 3 deletions python/sglang/srt/layers/quantization/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,22 @@ def _modelopt_override_quantization_method(

# If user specified generic "modelopt", auto-detect the specific method
if user_quant == "modelopt":
if "FP8" in quant_algo:
# Check MXFP8 before FP8 since "MXFP8" contains "FP8"
if "MXFP8" in quant_algo:
return "modelopt_mxfp8"
elif "FP8" in quant_algo:
return "modelopt_fp8"
elif "NVFP4" in quant_algo or "FP4" in quant_algo:
return "modelopt_fp4"

# The hf_quant_config may be a parsed quant config, so we need to check the
# quant_method.
if hf_quant_config.get("quant_method", "") == "modelopt_fp8":
quant_method = hf_quant_config.get("quant_method", "")
if quant_method == "modelopt_mxfp8":
return "modelopt_mxfp8"
elif quant_method == "modelopt_fp8":
return "modelopt_fp8"
elif hf_quant_config.get("quant_method", "") == "modelopt_fp4":
elif quant_method == "modelopt_fp4":
return "modelopt_fp4"

return None
Expand Down
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.cutlass_moe import CUTLASS_MOE_WORKSPACE_BYTES
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
FlashInferTrtllmFp8MoeQuantInfo,
Expand Down Expand Up @@ -1553,7 +1554,9 @@ def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None:
self.c_strides2 = torch.full(
(num_experts,), hidden_size, device=device, dtype=torch.int64
)
self.workspace = torch.empty(90000, device=device, dtype=torch.uint8)
self.workspace = torch.empty(
CUTLASS_MOE_WORKSPACE_BYTES, device=device, dtype=torch.uint8
)
self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
Expand Down
25 changes: 25 additions & 0 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,31 @@ def _pack_mxfp8_scales(scale_u8: torch.Tensor) -> torch.Tensor:
return packed.view(1, scale_m, scale_k, 2, 256)


def dequantize_mxfp8(
    data: torch.Tensor, scale_u8: torch.Tensor, group_size: int = 32
) -> torch.Tensor:
    """Dequantize an MXFP8 tensor with UE8M0 block scales back to bf16.

    Each run of ``group_size`` consecutive elements along the last dimension
    shares one scale byte; the dequantized value is
    ``fp8_val * 2^(scale - 127)``.

    Args:
        data: 2-D tensor of shape (M, K) holding the quantized values. It is
            upcast to float32 before scaling.
        scale_u8: UE8M0 scale bytes, one per group, i.e. M * (K // group_size)
            elements in row-major group order (typically shaped
            (M, K // group_size)). It is moved to ``data.device`` if needed.
        group_size: Number of elements per scale group (default 32).

    Returns:
        bf16 tensor of shape (M, K) on ``data.device``.

    Raises:
        ValueError: If ``data`` is not 2-D, ``group_size`` is not positive,
            K is not a multiple of ``group_size``, or ``scale_u8`` does not
            contain exactly one scale per group.
    """
    # NOTE(review): a reviewer on this change suggested reusing
    # triton_kernels.numerics_details.mxfp.upcast_from_mxfp_torch instead of a
    # hand-written dequantization (raising a RuntimeError when triton_kernels
    # lacks MXFP8 support). Not adopted here -- confirm before merging.
    if data.dim() != 2:
        raise ValueError(
            f"dequantize_mxfp8 expects a 2-D tensor, got shape {tuple(data.shape)}"
        )
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")

    m, k = data.shape
    if k % group_size != 0:
        raise ValueError(
            f"last dimension ({k}) must be a multiple of group_size ({group_size})"
        )
    n_groups = k // group_size
    if scale_u8.numel() != m * n_groups:
        raise ValueError(
            f"expected {m * n_groups} scales for data of shape {(m, k)}, "
            f"got scale shape {tuple(scale_u8.shape)}"
        )

    # UE8M0 stores a biased power-of-two exponent (bias 127).
    # NOTE(review): a scale byte of 255 gives 2^128, which overflows float32
    # to inf -- confirm whether inputs can contain that value.
    exponents = scale_u8.to(dtype=torch.float32, device=data.device) - 127.0
    scales_f32 = torch.pow(2.0, exponents).reshape(m, n_groups, 1)

    # reshape (not view) so non-contiguous inputs, e.g. transposed weights,
    # are handled instead of raising a stride error.
    data_f32 = data.to(torch.float32).reshape(m, n_groups, group_size)
    return (data_f32 * scales_f32).reshape(m, k).to(torch.bfloat16)


def triton_mxfp8_blockscaled_linear(
input: torch.Tensor,
weight: torch.Tensor,
Expand Down
Loading
Loading