Merged
Changes from 3 commits
python/sglang/srt/model_loader/loader.py (10 additions, 1 deletion)
@@ -145,11 +145,20 @@ def _initialize_model(
     load_config: LoadConfig,
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
-    model_class, _ = get_model_architecture(model_config)
+    model_class, model_arch = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
     quant_config = _get_quantization_config(
         model_config, load_config, packed_modules_mapping
     )
+    if (
+        quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+        and model_arch == "DeepseekV3ForCausalLMNextN"
+    ):
+        logger.warning(
+            "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+        )
+        quant_config = None
     return model_class(
         config=model_config.hf_config,
         quant_config=quant_config,
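For reference, the override above only fires for modelopt_fp4 quantization combined with the DeepseekV3ForCausalLMNextN architecture; every other combination keeps its quant config. Below is a minimal stand-alone sketch of that decision, where should_drop_quant_config is a hypothetical helper written for illustration and not part of this PR:

# Hypothetical helper mirroring the condition added in _initialize_model.
def should_drop_quant_config(quant_name, model_arch):
    # The NextN (MTP) draft module of a modelopt_fp4 Deepseek checkpoint is
    # built without a quant config (the loader sets quant_config to None).
    return quant_name == "modelopt_fp4" and model_arch == "DeepseekV3ForCausalLMNextN"

assert should_drop_quant_config("modelopt_fp4", "DeepseekV3ForCausalLMNextN")
assert not should_drop_quant_config("modelopt_fp4", "DeepseekV3ForCausalLM")
assert not should_drop_quant_config("awq", "DeepseekV3ForCausalLMNextN")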
python/sglang/srt/models/deepseek_v2.py (8 additions, 1 deletion)
@@ -2147,7 +2147,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
                     q_a_proj_weight = cached_a_proj[q_a_proj_name]
                     kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                     cat_dim = 0
-                    if (
+                    if self.quant_config is not None and (
                         self.quant_config.get_name() == "awq"
                         or self.quant_config.get_name() == "moe_wna16"
                     ):
Comment on lines +2204 to 2207 (Contributor, severity: medium):
Good addition of the self.quant_config is not None check. This prevents a potential AttributeError when self.quant_config is None, which is now possible due to the changes in python/sglang/srt/model_loader/loader.py for the Deepseek R1 FP4 model.

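To make the reviewer's point concrete, here is a minimal, runnable sketch of the short-circuit guard; _FakeQuantConfig is a hypothetical stand-in for a real quantization config object:

class _FakeQuantConfig:
    """Hypothetical stand-in for a quantization config object."""

    def get_name(self):
        return "awq"


for quant_config in (None, _FakeQuantConfig()):
    cat_dim = 0
    # Without the `is not None` check, the None case would raise
    # AttributeError: 'NoneType' object has no attribute 'get_name'.
    if quant_config is not None and (
        quant_config.get_name() == "awq"
        or quant_config.get_name() == "moe_wna16"
    ):
        cat_dim = 1
    print(type(quant_config).__name__, cat_dim)  # NoneType 0, then _FakeQuantConfig 1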
@@ -2178,6 +2178,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
                 for scale in ["k_scale", "v_scale"]:
                     if scale in name:
                         name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                        break
Contributor comment (severity: medium):
Adding break here is a good optimization. It ensures the loop terminates as soon as the scale is found and name is updated, avoiding unnecessary iterations.

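As a quick illustration of the rename-and-break flow, here is a sketch using a hypothetical checkpoint weight name (the real keys depend on the checkpoint):

# Hypothetical example key; real modelopt checkpoint keys may differ.
name = "model.layers.0.self_attn.k_proj.k_scale"

for scale in ["k_scale", "v_scale"]:
    if scale in name:
        # "k_scale"[0] == "k", so "k_proj" is rewritten to "attn_mqa".
        name = name.replace(f"{scale[0]}_proj", "attn_mqa")
        break  # stop as soon as the matching scale has been handled

print(name)  # model.layers.0.self_attn.attn_mqa.k_scale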
+                if name not in params_dict:
+                    # modelopt ckpt contains not needed weights for MTP module:
+                    # model.decoder.self_attn.attn_mqa.v_scale and
+                    # model.decoder.self_attn.attn_mqa.k_scale
+                    logger.warning(f"{name} not found in params_dict.")
+                    continue
Comment on lines +2236 to +2241 (Contributor, severity: medium):
Checking that name is in params_dict before accessing it is a robust addition. The warning and the accompanying comment make it clear why certain weights may be skipped, which helps with debugging and with understanding model-loading behavior, especially for modelopt checkpoints that can contain extra weights not needed by specific modules such as MTP.

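A small sketch of the skip-with-warning pattern discussed above; params_dict and "some.expected.weight" are hypothetical stand-ins, while the k_scale/v_scale keys are the ones named in the diff comment:

import logging

logger = logging.getLogger(__name__)

# Hypothetical parameter dict; a real one would typically come from
# dict(self.named_parameters()) inside load_weights.
params_dict = {"some.expected.weight": object()}
checkpoint_names = [
    "some.expected.weight",
    "model.decoder.self_attn.attn_mqa.k_scale",  # present in modelopt ckpt, not needed by MTP
    "model.decoder.self_attn.attn_mqa.v_scale",
]

for name in checkpoint_names:
    if name not in params_dict:
        logger.warning(f"{name} not found in params_dict.")
        continue  # skip extra checkpoint weights instead of raising KeyError
    param = params_dict[name]  # safe: only reached when name exists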
                 param = params_dict[name]
                 weight_loader = getattr(
                     param, "weight_loader", default_weight_loader