8 changes: 7 additions & 1 deletion python/sglang/srt/layers/linear.py
@@ -257,7 +257,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.dtype == loaded_weight.dtype
), "init para dtype and loaded weight dtype should be the same"

assert param.size() == loaded_weight.size()
assert param.size() == loaded_weight.size(), (
f"ReplicatedLinear weight size mismatch: "
f"param.size()={list(param.size())}, "
f"loaded_weight.size()={list(loaded_weight.size())}, "
f"param.dtype={param.dtype}, "
f"loaded_weight.dtype={loaded_weight.dtype}"
)
param.data.copy_(loaded_weight)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
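
Note on the hunk above: the bare `assert param.size() == loaded_weight.size()` is replaced with a message that reports both shapes and dtypes. A small, hypothetical repro (shapes invented for illustration, not taken from the PR) of what the new failure message looks like:

```python
import torch
from torch.nn import Parameter

# Hypothetical mismatched shapes, only to show the message the new assertion produces.
param = Parameter(torch.empty(1024, 512, dtype=torch.bfloat16))
loaded_weight = torch.empty(1024, 256, dtype=torch.bfloat16)

try:
    assert param.size() == loaded_weight.size(), (
        f"ReplicatedLinear weight size mismatch: "
        f"param.size()={list(param.size())}, "
        f"loaded_weight.size()={list(loaded_weight.size())}, "
        f"param.dtype={param.dtype}, "
        f"loaded_weight.dtype={loaded_weight.dtype}"
    )
except AssertionError as e:
    print(e)
# ReplicatedLinear weight size mismatch: param.size()=[1024, 512],
# loaded_weight.size()=[1024, 256], param.dtype=torch.bfloat16,
# loaded_weight.dtype=torch.bfloat16
```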
62 changes: 43 additions & 19 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -290,9 +290,13 @@ def _get_quant_method(
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

if isinstance(layer, LinearBase):
if is_layer_skipped(
prefix, self.exclude_modules, self.packed_modules_mapping
) or self.is_layer_excluded(prefix):
skipped = is_layer_skipped(
prefix,
self.exclude_modules,
self.packed_modules_mapping or {},
)
excluded = self.is_layer_excluded(prefix)
if skipped or excluded:
return UnquantizedLinearMethod()
return Linear(self)
elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
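
The refactor above names the two checks (`skipped` / `excluded`) and guards `self.packed_modules_mapping` against `None` before handing it to `is_layer_skipped`. The sketch below is a simplified stand-in for that helper (the real one lives in sglang's quantization utilities and may differ in detail), shown only to illustrate why the fused-module mapping must be a dict rather than `None`:

```python
# Simplified stand-in for is_layer_skipped; illustrative only, not the real
# implementation. Fused module names are looked up in packed_modules_mapping
# to find the unfused shards they pack.
def is_layer_skipped_sketch(prefix, exclude_modules, packed_modules_mapping):
    packed_modules_mapping = packed_modules_mapping or {}  # same None guard as in the PR
    parent, _, last = prefix.rpartition(".")
    shards = packed_modules_mapping.get(last, [last])
    # A fused layer is treated as skipped when every shard it packs is excluded.
    return all(f"{parent}.{shard}" in exclude_modules for shard in shards)

exclude = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
print(is_layer_skipped_sketch("model.layers.0.self_attn.qkv_proj", exclude, fused))  # True
print(is_layer_skipped_sketch("model.layers.1.self_attn.qkv_proj", exclude, fused))  # False
```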
@@ -1024,24 +1028,44 @@ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
)

def is_layer_excluded(self, prefix: str):
import regex as re
import re

fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
prefix_split = prefix.split(".")
for pattern in self.exclude_modules:
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
pattern_split = pattern.split(".")
if re.fullmatch(regex_str, prefix):
return True
elif (
pattern_split[-1] in fused_patterns
and pattern_split[-1] in prefix_split[-1]
):
# Check if the last part of the excluded pattern is contained in the last part of the prefix
# This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
# e.g., model.layers.{i}.self_attn.{fused_weight_name}
assert len(prefix_split) == 5 and len(pattern_split) == 5
return True
# Build candidate prefixes to handle naming mismatches between
# SGLang model prefixes and checkpoint ignore patterns.
# E.g., Kimi K2.5 VLM: SGLang prefix is "model.layers.X.self_attn.Y"
# but checkpoint ignore patterns use "language_model.layers.X.self_attn*".
prefixes_to_check = [prefix]
if prefix.startswith("language_model.model."):
# language_model.model.X -> language_model.X (drop inner "model.")
prefixes_to_check.append(
"language_model." + prefix.removeprefix("language_model.model.")
)
# language_model.model.X -> model.X (drop "language_model.")
prefixes_to_check.append(prefix.removeprefix("language_model."))
elif prefix.startswith("model."):
# model.X -> language_model.X (replace "model." with "language_model.")
prefixes_to_check.append("language_model." + prefix.removeprefix("model."))
elif prefix.startswith("language_model."):
prefixes_to_check.append(prefix.removeprefix("language_model."))

for check_prefix in prefixes_to_check:
check_prefix_split = check_prefix.split(".")
for pattern in self.exclude_modules:
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
pattern_split = pattern.split(".")
if re.fullmatch(regex_str, check_prefix):
return True
elif (
pattern_split[-1] in fused_patterns
and pattern_split[-1] in check_prefix_split[-1]
and len(check_prefix_split) == 5
and len(pattern_split) == 5
):
# Check if the last part of the excluded pattern is contained in the last part of the prefix
# This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
# e.g., model.layers.{i}.self_attn.{fused_weight_name}
return True
return False
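
The rewritten `is_layer_excluded` expands each module prefix into several candidate spellings so that checkpoint ignore patterns written against a different top-level name (the Kimi K2.5 VLM case described in the comment) still match. Below is a standalone sketch of the candidate expansion and the glob-to-regex match; the example prefix and pattern are illustrative, not taken from the PR:

```python
import re

def candidate_prefixes(prefix: str) -> list[str]:
    # Mirrors the remapping in is_layer_excluded: emit alternate spellings of
    # the module path so patterns written against "language_model.*" or
    # "model.*" can still match the SGLang prefix.
    candidates = [prefix]
    if prefix.startswith("language_model.model."):
        rest = prefix.removeprefix("language_model.model.")
        candidates += ["language_model." + rest, "model." + rest]
    elif prefix.startswith("model."):
        candidates.append("language_model." + prefix.removeprefix("model."))
    elif prefix.startswith("language_model."):
        candidates.append(prefix.removeprefix("language_model."))
    return candidates

def matches(pattern: str, prefix: str) -> bool:
    # "." is literal and "*" is a wildcard in checkpoint ignore patterns.
    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    return re.fullmatch(regex_str, prefix) is not None

# Illustrative values only: an SGLang prefix vs. a checkpoint-style pattern.
pattern = "language_model.layers.*.self_attn*"
prefix = "model.layers.3.self_attn.q_b_proj"
print([p for p in candidate_prefixes(prefix) if matches(pattern, p)])
# ['language_model.layers.3.self_attn.q_b_proj']
```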

def get_quant_method(self, layer: torch.nn.Module, prefix: str):