50 changes: 38 additions & 12 deletions python/sglang/srt/models/deepseek_v2.py
@@ -1903,23 +1903,49 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
+            if self.quant_config is not None:
+                if self.quant_config.get_name() == "w8a8_int8":
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale",
+                        "up_proj.weight",
+                        "up_proj.weight_scale",
+                    ]
+                elif (
+                    self.quant_config.get_name() == "fp8"
+                    or self.quant_config.get_name() == "blockwise_int8"
+                ):
+                    suffix_list = [
+                        "down_proj.weight",
+                        "down_proj.weight_scale_inv",
+                        "gate_proj.weight",
+                        "gate_proj.weight_scale_inv",
+                        "up_proj.weight",
+                        "up_proj.weight_scale_inv",
+                    ]
+                elif self.quant_config.get_name() == "awq":
+                    suffix_list = [
+                        "down_proj.qweight",
+                        "down_proj.qzeros",
+                        "down_proj.scales",
+                        "gate_proj.qweight",
+                        "gate_proj.qzeros",
+                        "gate_proj.scales",
+                        "up_proj.qweight",
+                        "up_proj.qzeros",
+                        "up_proj.scales",
+                    ]
+                else:
+                    raise ValueError(
+                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
+                    )
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale_inv",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale_inv",
+                    "up_proj.weight",
+                    "up_proj.weight_scale_inv",
+                ]
             names_to_remove = []
 
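As a reading aid for the hunk above: each quantization format stores the shared experts' projection tensors under different checkpoint suffixes, and the fusion path must copy every one of them. The following is a hypothetical standalone sketch, not sglang's actual code; the format names and suffixes are taken verbatim from the hunk, while the function and variable names are invented for illustration, and the if/elif chain is rewritten as a dict dispatch.

```python
# Minimal sketch (assumption: not sglang's real helper) of the suffix-selection
# logic the diff introduces for shared expert fusion.
from typing import List, Optional

_PROJS = ("down_proj", "gate_proj", "up_proj")

# Per-format tensor suffixes, mirroring the branches in the diff. The None key
# corresponds to the diff's "no quant_config set" fallback, which uses the
# *_scale_inv suffixes.
_SUFFIXES_BY_FORMAT = {
    "w8a8_int8": ("weight", "weight_scale"),
    "fp8": ("weight", "weight_scale_inv"),
    "blockwise_int8": ("weight", "weight_scale_inv"),
    "awq": ("qweight", "qzeros", "scales"),
    None: ("weight", "weight_scale_inv"),
}


def shared_expert_suffix_list(quant_name: Optional[str]) -> List[str]:
    """Return the flat suffix list (e.g. 'down_proj.qweight') for one format."""
    try:
        parts = _SUFFIXES_BY_FORMAT[quant_name]
    except KeyError:
        # Same failure mode as the diff's final else branch.
        raise ValueError(
            f"Unsupported shared expert fusion for quantization: {quant_name}."
        )
    # Group by projection first, matching the ordering in the diff's lists.
    return [f"{proj}.{part}" for proj in _PROJS for part in parts]


if __name__ == "__main__":
    # AWQ packs weights as qweight/qzeros/scales, so its list has 9 entries.
    print(shared_expert_suffix_list("awq"))
```

A table keyed by format collapses the branching and makes the unsupported-format error fall out of the lookup; the diff's explicit if/elif chain trades that compactness for staying consistent with the style of the surrounding load_weights code.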