Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 58 additions & 6 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def _load_w2(
# Narrow parameter and load.
if is_bias:
# this expert_data is a bias, not weight,
# for w2_bias in TP, it does not need to be sharded
# for w2_weight_bias in TP, it does not need to be sharded
shard_size = expert_data.shape[-1]
else:
# this parameter is a weight matrix
Expand All @@ -410,10 +410,6 @@ def _load_w2(
if not is_bias and not self.use_presharded_weights:
if self.use_triton_kernels:
loaded_weight = loaded_weight.transpose(-2, -1)
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
raise ValueError(
f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
)
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
Expand Down Expand Up @@ -461,9 +457,25 @@ def weight_loader(
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
expert_id: Optional[int],
) -> None:

# if expert_id is None, then
# all the experts are loaded at the same time
if (
not expert_id
and self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
):
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
param.data[:, :dim1].copy_(loaded_weight)
else:
dim1 = loaded_weight.shape[1]
dim2 = loaded_weight.shape[2]
param.data[:, :dim1, :dim2].copy_(loaded_weight)
return

global_expert_location_metadata = get_global_expert_location_metadata()
if global_expert_location_metadata is None:
self._weight_loader_impl(
Expand Down Expand Up @@ -502,6 +514,7 @@ def _weight_loader_physical(
shard_id: str,
expert_id: int,
) -> None:

expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return
Expand Down Expand Up @@ -705,6 +718,18 @@ def weight_loader_fused(
) -> None:
tp_rank = self.moe_tp_rank

if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
param.data[:, :dim1].copy_(loaded_weight)
elif "scale" in weight_name:
param.data.copy_(loaded_weight)
else:
dim1 = loaded_weight.shape[1]
dim2 = loaded_weight.shape[2]
param.data[:, :dim1, :dim2].copy_(loaded_weight)
return

# compressed-tensors checkpoints with packed weights are stored flipped
# TODO: check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
Expand Down Expand Up @@ -854,6 +879,33 @@ def make_expert_params_mapping_fused(
("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
]

@classmethod
def make_expert_params_mapping_fused_mxfp4(
cls,
ckpt_gate_up_proj_name: str,
ckpt_down_proj_name: str,
ckpt_gate_up_proj_bias_name: str,
ckpt_down_proj_bias_name: str,
ckpt_gate_up_proj_scale_name: str,
ckpt_down_proj_scale_name: str,
):
return [
("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
(
"experts.w13_weight_bias",
f"experts.{ckpt_gate_up_proj_bias_name}",
"w13",
),
("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
(
"experts.w13_weight_scale",
f"experts.{ckpt_gate_up_proj_scale_name}",
"w13",
),
("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"),
]

@classmethod
def make_expert_input_scale_params_mapping(
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,10 @@ def triton_kernel_fused_experts(
def triton_kernel_moe_with_bias_forward(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_pcg,
b1: torch.Tensor,
w2: torch.Tensor,
w2_pcg,
b2: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
Expand All @@ -209,13 +211,15 @@ def triton_kernel_moe_with_bias_forward(

return triton_kernel_fused_experts_with_bias(
hidden_states,
w1,
b1,
w2,
b2,
routing_data,
gather_idx,
scatter_idx,
w1=w1,
w1_pcg=w1_pcg,
b1=b1,
w2=w2,
w2_pcg=w2_pcg,
b2=b2,
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=scatter_idx,
inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
Expand All @@ -235,8 +239,10 @@ def triton_kernel_moe_with_bias_forward(
def triton_kernel_fused_experts_with_bias(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_pcg,
b1: torch.Tensor,
w2: torch.Tensor,
w2_pcg,
b2: torch.Tensor,
routing_data: RoutingData,
gather_indx: GatherIndx,
Expand Down Expand Up @@ -267,8 +273,10 @@ def triton_kernel_fused_experts_with_bias(

# type check
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
for w in (w1, w2):
# TODO assert bf16 or mxfp4
# assert (w.dtype == torch.bfloat16) or check-is-mxfp4, f"w must be bfloat16 or mxfp4 {w1.dtype=}"
pass

# Shape check
assert hidden_states.ndim == 2, "hidden_states must be 2D"
Expand All @@ -287,13 +295,15 @@ def triton_kernel_fused_experts_with_bias(
if global_num_experts == -1:
global_num_experts = E

device = "cuda"
optg = dict()
w1, w1_flex = quantize(w1, "bf16", device, **optg)
w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
# TODO maybe completely remove this branch
if w1.dtype == torch.bfloat16:
device = "cuda"
optg = dict()
w1, w1_flex = quantize(w1, "bf16", device, **optg)
w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))

w2, w2_flex = quantize(w2, "bf16", device, **optg)
w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
w2, w2_flex = quantize(w2, "bf16", device, **optg)
w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))

act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
Expand Down
14 changes: 12 additions & 2 deletions python/sglang/srt/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def override_quantization_method(self, *args, **kwargs):
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from sglang.srt.utils import mxfp_supported
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported

is_mxfp_supported = mxfp_supported()
if is_mxfp_supported:
Expand All @@ -66,6 +66,7 @@ def override_quantization_method(self, *args, **kwargs):
ModelOptFp8Config,
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
from sglang.srt.layers.quantization.petit import PetitNvFp4Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.utils import get_linear_quant_method
Expand All @@ -90,7 +91,16 @@ def override_quantization_method(self, *args, **kwargs):
"w4afp8": W4AFp8Config,
"petit_nvfp4": PetitNvFp4Config,
}
if is_mxfp_supported:


if is_cuda():
BASE_QUANTIZATION_METHODS.update(
{
"quark": Mxfp4Config,
"mxfp4": Mxfp4Config,
}
)
elif is_mxfp_supported and is_hip():
BASE_QUANTIZATION_METHODS.update(
{
"quark": MxFp4Config,
Expand Down
Loading
Loading