Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,20 +834,29 @@ def _verify_quantization(self) -> None:
if self.quantization is None:
self.quantization = quant_method
elif self.quantization != quant_method:
# Allow auto-detection of quantization from checkpoint for draft model
# even if it differs from main model's quantization
if self.is_draft_model:
# Check if the CLI-specified quantization is compatible with HF config's quant_method
is_compatible = (
self.quantization in compatible_quantization_methods
and quant_method
in compatible_quantization_methods[self.quantization]
)
if is_compatible:
# Keep the CLI-specified quantization (e.g., modelopt_fp4) even if
# HF config says "modelopt" - they are compatible
logger.info(
f"Using CLI-specified quantization ({self.quantization}) which is "
f"compatible with HF config quant_method ({quant_method})."
)
elif self.is_draft_model:
# Allow auto-detection of quantization from checkpoint for draft model
# only if the CLI quantization is not compatible
logger.info(
f"Draft model quantization ({quant_method}) differs from "
f"main model quantization ({self.quantization}). "
f"Using draft model's detected quantization: {quant_method}"
)
self.quantization = quant_method
elif (
self.quantization not in compatible_quantization_methods
or quant_method
not in compatible_quantization_methods[self.quantization]
):
else:
raise ValueError(
"Quantization method specified in the model config "
f"({quant_method}) does not match the quantization "
Expand Down
14 changes: 14 additions & 0 deletions python/sglang/srt/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
get_quant_config,
gguf_quant_weights_iterator,
initialize_dummy_weights,
maybe_add_mtp_safetensors,
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator,
np_cache_weights_iterator,
Expand Down Expand Up @@ -321,13 +322,17 @@ class Source:
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""

model_config: Optional["ModelConfig"] = None
"""The model configuration (for checking architecture, etc)."""

@classmethod
def init_new(cls, model_config: ModelConfig, model):
return cls(
model_config.model_path,
model_config.revision,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
model_config=model_config,
)

def __init__(self, load_config: LoadConfig):
Expand Down Expand Up @@ -471,6 +476,15 @@ def _get_weights_iterator(
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt
)

if use_safetensors and source.model_config is not None:
hf_weights_files = maybe_add_mtp_safetensors(
hf_weights_files,
hf_folder,
"model.safetensors.index.json",
source.model_config.hf_config,
)

if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
Expand Down
38 changes: 38 additions & 0 deletions python/sglang/srt/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,44 @@ def filter_duplicate_safetensors_files(
return hf_weights_files


def maybe_add_mtp_safetensors(
    hf_weights_files: List[str], hf_folder: str, index_file: str, hf_config
) -> List[str]:
    """
    Auto-detect and add mtp.safetensors for GLM4Moe MTP/NextN models if:
    1. Model architecture is Glm4MoeForCausalLM with num_nextn_predict_layers > 0
    2. mtp.safetensors exists in the model directory
    3. mtp.safetensors is NOT in the index (checkpoint packaging bug)

    This works around incorrectly packaged FP4 checkpoints like
    baseten-admin/glm-4.7-fp4 where mtp.safetensors exists but
    isn't referenced in model.safetensors.index.json.

    Args:
        hf_weights_files: Weight files resolved from the safetensors index.
        hf_folder: Local directory containing the checkpoint files.
        index_file: Index file name, used only in the warning message.
        hf_config: HuggingFace model config; ``architectures`` and
            ``num_nextn_predict_layers`` are read if present.

    Returns:
        The (possibly extended) list of weight files. The input list is
        never mutated; a new list is returned when mtp.safetensors is added.
    """
    # Only apply for GLM4Moe architecture with nextn layers.
    # `or [None]` guards against `architectures` being present but empty
    # (or None), which would otherwise raise an IndexError on `[0]` --
    # the getattr default only covers a *missing* attribute.
    architectures = getattr(hf_config, "architectures", None) or [None]
    arch = architectures[0]
    num_nextn_layers = getattr(hf_config, "num_nextn_predict_layers", 0)
    if not (
        arch in ["Glm4MoeForCausalLM", "Glm4MoeForCausalLMNextN"]
        and num_nextn_layers > 0
    ):
        return hf_weights_files

    # Check if mtp.safetensors exists and is not already in the file list
    mtp_path = os.path.join(hf_folder, "mtp.safetensors")
    if not os.path.isfile(mtp_path) or mtp_path in hf_weights_files:
        return hf_weights_files

    # mtp.safetensors exists but not in index - this is a packaging bug.
    # Lazy %-formatting so the string is only built if the record is emitted.
    logger.warning(
        "Found mtp.safetensors but it's not referenced in %s. "
        "This is a checkpoint packaging bug. Auto-adding it for loading. "
        "Please report this to the checkpoint provider.",
        index_file,
    )

    # Add it to the files list (new list; caller's list is untouched)
    return hf_weights_files + [mtp_path]


def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[str]:
"""
Exclude files that are not needed for inference.
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import filter_moe_weight_param_global_expert
from sglang.srt.layers.moe.utils import (
RoutingMethodType,
filter_moe_weight_param_global_expert,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.radix_attention import RadixAttention
Expand Down Expand Up @@ -376,6 +379,7 @@ def __init__(
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
routing_method_type=RoutingMethodType.DeepSeekV3,
prefix=add_prefix("experts", prefix),
)

Expand Down
22 changes: 22 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from sglang.srt.utils.common import (
LORA_TARGET_ALL_MODULES,
SUPPORTED_LORA_TARGET_MODULES,
check_pkg_version_at_least,
configure_ipv6,
cpu_has_amx_support,
get_bool_env_var,
Expand Down Expand Up @@ -1509,6 +1510,27 @@ def _handle_model_specific_adjustments(self):
)
self.disable_radix_cache = True
self.disable_overlap_schedule = False
elif model_arch in ["Glm4MoeForCausalLM"]:
if is_sm100_supported():
quantization_config = getattr(hf_config, "quantization_config", None)
quant_method = (
quantization_config.get("quant_method")
if quantization_config is not None
else None
)
if self.quantization is None and quant_method is not None:
self.quantization = quant_method
if (
self.quantization == "modelopt_fp4"
and self.moe_a2a_backend == "none"
and self.moe_runner_backend == "auto"
):
# Only enable flashinfer_trtllm if flashinfer-python version is >= 0.6.2
if check_pkg_version_at_least("flashinfer-python", "0.6.2"):
self.moe_runner_backend = "flashinfer_trtllm"
logger.info(
"Use flashinfer_trtllm as MoE runner backend on sm100 for Glm4MoeForCausalLM"
)
Comment on lines +1513 to +1533
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This `if` branch has been inserted into the wrong place!

Qwen3NextForCausalLM's Mamba radix cache v2 is inside the Glm4MoeForCausalLM branch now

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jimmy-evo, sorry about breaking this code. It looks like it has been fixed in main, though. I will be more careful next time.


# Mamba radix cache v2
if self.enable_mamba_extra_buffer():
Expand Down
18 changes: 18 additions & 0 deletions python/sglang/srt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,24 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
)


def check_pkg_version_at_least(pkg: str, min_version: str) -> bool:
    """
    Check if a package is installed and meets the minimum version requirement.

    Args:
        pkg: Package name (distribution name, e.g., "flashinfer-python")
        min_version: Minimum version required (e.g., "0.6.2")

    Returns:
        True if package is installed and version >= min_version, False otherwise
    """
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        # Distribution is not installed at all.
        return False
    return pkg_version.parse(installed) >= pkg_version.parse(min_version)


def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
"""Kill the process and all its child processes."""
# Remove sigchld handler to avoid spammy logs.
Expand Down
Loading