diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 7dd63c29ecd..997a1069401 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -26,7 +26,11 @@ Changelog - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. -0.44 (2026-05-18) +**Bug Fixes** + +- Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance. + +0.44 (2026-05-14) ^^^^^^^^^^^^^^^^^ **New Features** diff --git a/modelopt/torch/export/plugins/mcore_deepseek.py b/modelopt/torch/export/plugins/mcore_deepseek.py index d02259e3530..2f9ef40f08e 100644 --- a/modelopt/torch/export/plugins/mcore_deepseek.py +++ b/modelopt/torch/export/plugins/mcore_deepseek.py @@ -43,6 +43,10 @@ "linear_kv_up_proj": NameRemapping("model.layers.{}.self_attn.kv_b_proj."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), + # Fused TE spec (mirrors the import side). MLA has no linear_qkv so + # fused_input_layernorm is inert today; fused_pre_mlp_layernorm reaches dense layers. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), # MLP for dense layers "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), @@ -88,6 +92,11 @@ "output_layer": NameRemapping("lm_head.", COL_TP), # Per-layer "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale. + # MLA has no linear_qkv so fused_input_layernorm is inert for DeepSeek today; included + # for parity in case a future spec fuses the layernorm into a Q/KV projection. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_q_proj": NameRemapping("model.layers.{}.self_attn.q_proj.", COL_TP), "linear_q_down_proj": NameRemapping("model.layers.{}.self_attn.q_a_proj.", REPLICATE), "linear_q_layernorm": NameRemapping("model.layers.{}.self_attn.q_a_layernorm.", REPLICATE), diff --git a/modelopt/torch/export/plugins/mcore_gptoss.py b/modelopt/torch/export/plugins/mcore_gptoss.py index c16347fbf0b..989aa7e67d7 100644 --- a/modelopt/torch/export/plugins/mcore_gptoss.py +++ b/modelopt/torch/export/plugins/mcore_gptoss.py @@ -31,6 +31,8 @@ gptoss_causal_lm_export: dict[str, CustomModuleMapping | bool] = { "word_embeddings": NameRemapping("model.embed_tokens."), "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), + # MoE-only on MLP side, so fused_pre_mlp_layernorm path is unreachable. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), "softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks"), @@ -52,6 +54,10 @@ gptoss_causal_lm_import: dict[str, CustomModuleMapping | bool] = { "word_embeddings": NameRemapping("model.embed_tokens.", COL_TP), "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale. + # gpt-oss is MoE-only on the MLP side (no layer.mlp.linear_fc1), so the importer's + # fused_pre_mlp_layernorm path is unreachable; only fused_input_layernorm is wired. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP), "softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks", COL_TP), diff --git a/modelopt/torch/export/plugins/mcore_llama.py b/modelopt/torch/export/plugins/mcore_llama.py index 7fb8ec76acf..80a5d9146a9 100644 --- a/modelopt/torch/export/plugins/mcore_llama.py +++ b/modelopt/torch/export/plugins/mcore_llama.py @@ -37,11 +37,13 @@ llama_causal_lm_export: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("model.embed_tokens."), "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), # KV cache quant export "core_attention": SelfAttentionScaling("model.layers.{}.self_attn."), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), "final_layernorm": NameRemapping("model.norm."), @@ -51,6 +53,8 @@ llama4_causal_lm_export: dict[str, CustomModuleMapping | bool] = { "word_embeddings": NameRemapping("language_model.model.embed_tokens."), "input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm."), + # MoE-only on MLP side, so fused_pre_mlp_layernorm path is unreachable. + "fused_input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.weight"), # self_attn "linear_qkv": QKVSlicing("language_model.model.layers.{}.self_attn."), "linear_proj": NameRemapping("language_model.model.layers.{}.self_attn.o_proj."), @@ -150,9 +154,12 @@ llama_causal_lm_import: dict[str, CustomModuleMapping] = { "word_embeddings": NameRemapping("model.embed_tokens.", COL_TP), "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale. + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP), "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP), "final_layernorm": NameRemapping("model.norm.", REPLICATE), @@ -162,6 +169,10 @@ llama4_causal_lm_import: dict[str, CustomModuleMapping | bool] = { "word_embeddings": NameRemapping("language_model.model.embed_tokens.", COL_TP), "input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale. + # Llama4 is MoE-only on the MLP side (no layer.mlp.linear_fc1), so the importer's + # fused_pre_mlp_layernorm path is unreachable; only fused_input_layernorm is wired. + "fused_input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("language_model.model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("language_model.model.layers.{}.self_attn.o_proj.", ROW_TP), "pre_mlp_layernorm": NameRemapping( diff --git a/modelopt/torch/export/plugins/mcore_qwen.py b/modelopt/torch/export/plugins/mcore_qwen.py index 5c4ae0647d8..4120a9a36d9 100644 --- a/modelopt/torch/export/plugins/mcore_qwen.py +++ b/modelopt/torch/export/plugins/mcore_qwen.py @@ -35,12 +35,17 @@ "output_layer": NameRemapping("lm_head.", COL_TP), # Attention "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + # Fused TE spec (TELayerNormColumnParallelLinear): the LayerNorm weight lives on + # linear_qkv.layer_norm_weight, loaded directly from the HF norm tensor (no `.weight` suffix + # appended since the value is a Parameter, not a sub-module). + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP), "q_layernorm": NameRemapping("model.layers.{}.self_attn.q_norm.", REPLICATE), "k_layernorm": NameRemapping("model.layers.{}.self_attn.k_norm.", REPLICATE), # MLP "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP), # MoE @@ -56,12 +61,14 @@ "output_layer": NameRemapping("lm_head."), # Attention "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), "q_layernorm": NameRemapping("model.layers.{}.self_attn.q_norm."), "k_layernorm": NameRemapping("model.layers.{}.self_attn.k_norm."), # MLP "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), # MoE @@ -76,10 +83,12 @@ "output_layer": NameRemapping("lm_head.", COL_TP), # Attention "input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE), + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP), # MLP "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP), } @@ -90,10 +99,12 @@ "output_layer": NameRemapping("lm_head."), # Attention "input_layernorm": NameRemapping("model.layers.{}.input_layernorm."), + "fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"), "linear_qkv": QKVSlicing("model.layers.{}.self_attn."), "linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."), # MLP "pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."), + "fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"), "linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."), "linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index e485731b3d8..7be98e6416b 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -238,8 +238,9 @@ def _gated_mlp_merging( else: prefix = prefix.replace("model", "mtp") - weight = module.state_dict().get("weight", None) - weight_scale = module.state_dict().get("weight_quantizer._scale", None) + module_state_dict = module.state_dict() + weight = module_state_dict.get("weight", None) + weight_scale = module_state_dict.get("weight_quantizer._scale", None) state_dict = {} @@ -273,6 +274,16 @@ def _gated_mlp_merging( else: state_dict["weight"] = tensor.to(self.dtype).to(device=weight.device) + # Preserve the fused LayerNorm weight + TE _extra_state already on the module so + # the strict load_state_dict below doesn't fail for TELayerNormColumnParallelLinear + # (fused under --export-default-te-spec). The actual HF norm tensor is loaded + # separately via the `fused_pre_mlp_layernorm` rule. + layer_norm_weight = module_state_dict.get("layer_norm_weight", None) + if layer_norm_weight is not None: + state_dict["layer_norm_weight"] = layer_norm_weight + if "_extra_state" in module_state_dict: + state_dict["_extra_state"] = module_state_dict["_extra_state"] + module.load_state_dict(state_dict) def _grouped_mlp_merging( @@ -433,7 +444,13 @@ def _qkv_merging( layer_norm_weight = module_state_dict.get("layer_norm_weight", None) if layer_norm_weight is not None: state_dict["layer_norm_weight"] = layer_norm_weight - state_dict["_extra_state"] = None # for TE modules require _extra_state key + # Preserve the TE metadata struct (FP8 amax history, recipe version, etc.) — + # `load_state_dict(..., strict=True)` requires the key, but blanking it could + # zero out per-module FP8 bookkeeping on TE versions that populate it. Only + # forward through when the source actually has it, to avoid adding an + # unexpected `_extra_state=None` to TE variants that don't. + if "_extra_state" in module_state_dict: + state_dict["_extra_state"] = module_state_dict["_extra_state"] module.load_state_dict(state_dict) @@ -599,14 +616,32 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = ) # TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear). - # Load the fused layer_norm_weight from the HF norm path. + # Prefer the per-context key (`fused_input_layernorm`); fall back to the legacy + # single-key `fused_norm` for Nemotron-H style (one norm shared across slots). + # Missing both is a plugin misconfig — raise rather than silently random-init. if ( isinstance(layer.input_layernorm, IdentityOp) and hasattr(attention, "linear_qkv") and hasattr(attention.linear_qkv, "layer_norm_weight") - and "fused_norm" in self.rules ): - self.rules["fused_norm"]( + fused_key = ( + "fused_input_layernorm" + if "fused_input_layernorm" in self.rules + else "fused_norm" + ) + if fused_key not in self.rules: + # Branch only fires when model uses fused TELayerNormColumnParallelLinear, + # so missing rule is unambiguously a plugin misconfiguration; raise so it + # doesn't silently ship a chance-accuracy checkpoint. + raise KeyError( + f"{self.arch} uses fused TELayerNormColumnParallelLinear for " + "attention but neither `fused_input_layernorm` nor legacy " + "`fused_norm` is in its import mapping; `linear_qkv.layer_norm_weight` " + "would be left at random init. Add " + '`fused_input_layernorm: NameRemapping("...input_layernorm.weight")` ' + f"to the {self.arch} import mapping." + ) + self.rules[fused_key]( attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp ) @@ -707,14 +742,27 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) # TE spec: pre_mlp_layernorm is fused into linear_fc1 - # (TELayerNormColumnParallelLinear). - # Load the fused layer_norm_weight from the HF norm path. - if ( - isinstance(layer.pre_mlp_layernorm, IdentityOp) - and hasattr(layer.mlp.linear_fc1, "layer_norm_weight") - and "fused_norm" in self.rules + # (TELayerNormColumnParallelLinear). See input_layernorm path above for the + # rule-key fallback rationale. + if isinstance(layer.pre_mlp_layernorm, IdentityOp) and hasattr( + layer.mlp.linear_fc1, "layer_norm_weight" ): - self.rules["fused_norm"]( + fused_key = ( + "fused_pre_mlp_layernorm" + if "fused_pre_mlp_layernorm" in self.rules + else "fused_norm" + ) + if fused_key not in self.rules: + raise KeyError( + f"{self.arch} uses fused TELayerNormColumnParallelLinear for " + "MLP but neither `fused_pre_mlp_layernorm` nor legacy " + "`fused_norm` is in its import mapping; " + "`linear_fc1.layer_norm_weight` would be left at random init. " + "Add `fused_pre_mlp_layernorm: NameRemapping(" + '"...post_attention_layernorm.weight")` ' + f"to the {self.arch} import mapping." + ) + self.rules[fused_key]( layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp ) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 23b8cfd1630..44529ba0fad 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -426,25 +426,33 @@ def _get_state_dict(self): if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: self.rules["output_layer"](model.output_layer) - def _get_fused_norm_weight(self, module): - """Return ``module.layer_norm_weight`` when TE fuses the norm into a linear layer. - - Returns ``None`` when the ``"fused_norm"`` rule is absent or the module has no - ``layer_norm_weight`` attribute (or its value is ``None``). + def _get_fused_norm_weight(self, module, primary_key: str = "fused_norm"): + """Return ``(rule_key, layer_norm_weight)`` when TE fuses the norm into a linear layer. + + Mirrors the importer-side fallback chain: prefer the per-context key + (``fused_input_layernorm`` for attention, ``fused_pre_mlp_layernorm`` for MLP) and + fall back to the legacy ``fused_norm`` rule (Nemotron-H style, one norm shared + across attention/mlp/mamba slots). Returns ``(None, None)`` when no rule is + defined or the module has no ``layer_norm_weight``. """ - if "fused_norm" not in self.rules: - return None - return getattr(module, "layer_norm_weight", None) + fused_key = primary_key if primary_key in self.rules else "fused_norm" + if fused_key not in self.rules: + return None, None + weight = getattr(module, "layer_norm_weight", None) + if weight is None: + return None, None + return fused_key, weight def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.input_layernorm, IdentityOp): self.rules["input_layernorm"](layer.input_layernorm, layer_id) - elif ( - norm_weight := self._get_fused_norm_weight( - getattr(layer.self_attention, "linear_qkv", None) + else: + fused_key, norm_weight = self._get_fused_norm_weight( + getattr(layer.self_attention, "linear_qkv", None), + primary_key="fused_input_layernorm", ) - ) is not None: - self.rules["fused_norm"](norm_weight, layer_id) + if norm_weight is not None: + self.rules[fused_key](norm_weight, layer_id) if not isinstance(layer.self_attention, IdentityOp): if "MLASelfAttention" in str(type(layer.self_attention)): @@ -483,13 +491,13 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.pre_mlp_layernorm, IdentityOp): self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id) - elif ( - not isinstance(layer.mlp, IdentityOp) - and "MoE" not in str(type(layer.mlp)) - and (norm_weight := self._get_fused_norm_weight(getattr(layer.mlp, "linear_fc1", None))) - is not None - ): - self.rules["fused_norm"](norm_weight, layer_id) + elif not isinstance(layer.mlp, IdentityOp) and "MoE" not in str(type(layer.mlp)): + fused_key, norm_weight = self._get_fused_norm_weight( + getattr(layer.mlp, "linear_fc1", None), + primary_key="fused_pre_mlp_layernorm", + ) + if norm_weight is not None: + self.rules[fused_key](norm_weight, layer_id) if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): @@ -597,9 +605,12 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]: def _get_mamba_layer_state_dict(self, layer, layer_id): if not isinstance(layer.norm, IdentityOp): self.rules["norm"](layer.norm, layer_id) - elif (norm_weight := self._get_fused_norm_weight(layer.mixer.in_proj)) is not None: + else: # TE spec: norm is fused into in_proj (QuantTELayerNormColumnParallelLinear). - self.rules["fused_norm"](norm_weight, layer_id) + # Mamba uses the legacy single-key `fused_norm` rule (Nemotron-H style). + fused_key, norm_weight = self._get_fused_norm_weight(layer.mixer.in_proj) + if norm_weight is not None: + self.rules[fused_key](norm_weight, layer_id) self.rules["mixer_norm"](layer.mixer.norm, layer_id) self.rules["A_log"](layer.mixer.A_log, layer_id) diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index 19f836f38dc..b993060c20e 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -79,6 +79,20 @@ except ImportError: HAS_MAMBA = False +# Newer Megatron-LM instantiates Nemotron-H et al. as plain HybridModel (MambaModel split +# out as a subclass). Register HybridModel so the dynamic-space converter sees them. +# DMRegistry._get_registered_nn_class filters by `nn_cls.forward is nn_cls_.forward` and +# returns the first match in insertion order: MambaModel is registered first, so +# MambaModel instances dispatch to MambaModel whether or not MambaModel overrides forward. +try: + from megatron.core.models.hybrid.hybrid_model import HybridModel + + SUPPORTED_MODELS[HybridModel] = "megatron.core.models.hybrid.HybridModel" + + HAS_HYBRID = True +except ImportError: + HAS_HYBRID = False + __all__ = ["get_te_mamba_stack_spec"] @@ -394,6 +408,9 @@ def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: Trace lambda mod, val: (num_attention_heads.active + 2 * mod.config.num_query_groups) * mod.config.kv_channels, ) + # in_features must track input_size so TE's forward-time inp_shape[-1] == in_features + # assertion holds when hidden_size is pruned. + self._register_dynamic_attribute("in_features", lambda mod, val: mod.input_size) self._register_dynamic_attribute("weight", self._get_weight) # TE stores a zero-length tensor (not None) when bias=False; only register if non-empty if hasattr(self, "bias") and self.bias is not None and self.bias.numel() > 0: diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 6a4e3828750..6fa90f96af4 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -37,7 +37,6 @@ import torch.nn as nn import torch.nn.functional as F from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear -from megatron.core.models.mamba.mamba_model import MambaModel from megatron.core.parallel_state import ( get_pipeline_model_parallel_group, get_pipeline_model_parallel_rank, @@ -56,6 +55,7 @@ from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.nas.plugins.megatron import ( + HAS_HYBRID, HAS_MAMBA, SUPPORTED_MODELS, _DynamicMambaLayer, @@ -173,6 +173,20 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i model.config.num_layers = new_num_layers +def _get_hybrid_pattern_key(model: nn.Module) -> str | None: + """Return the attribute name carrying the hybrid block pattern for hybrid models, else None. + + Handles both ``MambaModel`` (which still uses ``hybrid_override_pattern``) and plain + ``HybridModel`` (the parent class introduced in modern Megatron-LM, which carries + ``hybrid_layer_pattern``). Detecting by attribute presence avoids fragile isinstance + checks against a class hierarchy that may shift across MCore versions. + """ + for attr in ("hybrid_override_pattern", "hybrid_layer_pattern"): + if getattr(model, attr, None): + return attr + return None + + def _rprint(*renderables: Any) -> None: """Render rich renderables and print on rank 0 only.""" buf = io.StringIO() @@ -366,14 +380,9 @@ def run_search(self) -> None: # Prune homogeneously self._prune(export_config, prune_depth=True) - # TODO: Rename to hybrid_layer_pattern after MCore 0.17 and nemo:26.04 is released (for M-LM PR #3377) - # Update hybrid_override_pattern if pruning is done on a hybrid model - if isinstance(self.model, MambaModel): - hybrid_key = ( - "hybrid_override_pattern" - if hasattr(self.model, "hybrid_override_pattern") - else "hybrid_layer_pattern" - ) + # Update the hybrid block-type pattern if pruning a hybrid model. + hybrid_key = _get_hybrid_pattern_key(self.model) + if hybrid_key is not None: print_rank_0(f"Original {hybrid_key}: {getattr(self.model, hybrid_key)}") new_num_layers = self.model.config.num_layers assert self.sorted_layers is not None @@ -683,14 +692,9 @@ def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> di model = self.model active_metric_keys = self.constraints.keys() & _METRIC_CONSTRAINTS - # Get hybrid layer pattern for MambaModel (None for pure GPT) hybrid_layer_pattern: str | None = None - if isinstance(model, MambaModel): - hybrid_key = ( - "hybrid_override_pattern" - if hasattr(self.model, "hybrid_override_pattern") - else "hybrid_layer_pattern" - ) + hybrid_key = _get_hybrid_pattern_key(model) + if hybrid_key is not None: hybrid_layer_pattern = getattr(model, hybrid_key) # If depth pruning on a hybrid model, filter the pattern to only the kept layers. @@ -732,6 +736,14 @@ def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> di return metrics +_HYBRID_DIVISORS = { + "hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 512, + "mamba_head_dim_divisor": 8, + "num_moe_experts_divisor": 8, + "num_layers_divisor": 2, +} + MCoreMinitronConfig: type[ModeloptBaseConfig] = create_model( "MCoreMinitronConfig", **get_kwargs_for_create_model_with_rules( @@ -743,19 +755,8 @@ def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> di "num_moe_experts_divisor": 8, "num_layers_divisor": 2, }, - **( - { - "megatron.core.models.mamba.MambaModel": { - "hidden_size_divisor": 256, - "ffn_hidden_size_divisor": 512, - "mamba_head_dim_divisor": 8, - "num_moe_experts_divisor": 8, - "num_layers_divisor": 2, - } - } - if HAS_MAMBA - else {} - ), + **({"megatron.core.models.mamba.MambaModel": _HYBRID_DIVISORS} if HAS_MAMBA else {}), + **({"megatron.core.models.hybrid.HybridModel": _HYBRID_DIVISORS} if HAS_HYBRID else {}), }, doc='Configuration for the ``"mcore_minitron"`` mode.', ), diff --git a/modelopt/torch/utils/logging.py b/modelopt/torch/utils/logging.py index f5aba0d1a1f..724dda14434 100644 --- a/modelopt/torch/utils/logging.py +++ b/modelopt/torch/utils/logging.py @@ -111,8 +111,14 @@ def print_rank_0(*args, **kwargs): def warn_rank_0(message, *args, **kwargs): - """Issues a warning only on the master process.""" + """Issues a warning only on the master process. + + Auto-bumps ``stacklevel`` by 1 to skip this wrapper frame, so callers can pass the + same stacklevel they would to ``warnings.warn`` directly and the warning still + points at the user's call site. + """ if dist.is_master(): + kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1 warnings.warn(message, *args, **kwargs) diff --git a/modelopt/torch/utils/plugins/megatron_generate.py b/modelopt/torch/utils/plugins/megatron_generate.py index 5625013bb44..1c4c5d54647 100644 --- a/modelopt/torch/utils/plugins/megatron_generate.py +++ b/modelopt/torch/utils/plugins/megatron_generate.py @@ -150,7 +150,9 @@ def megatron_prefill( ) send_to_next_pipeline_rank(output.to(dtype=pp_dtype)) - logits = output[:, :seq_length, :].detach() if pp_last else None + # .contiguous() is required because the slice is a view with the padded stride; the broadcast + # below asserts contiguity when SP pads seq_length up to a multiple of TP. + logits = output[:, :seq_length, :].detach().contiguous() if pp_last else None if model.config.bf16: logits_dtype = torch.bfloat16 diff --git a/modelopt/torch/utils/plugins/megatron_mmlu.py b/modelopt/torch/utils/plugins/megatron_mmlu.py index 4a07405caff..6c70c5aee48 100644 --- a/modelopt/torch/utils/plugins/megatron_mmlu.py +++ b/modelopt/torch/utils/plugins/megatron_mmlu.py @@ -60,6 +60,7 @@ def megatron_mmlu( few_shots: int = 0, fraction: float = 0.05, batch_size: int = 1, + mmlu_dataset: str = "cais/mmlu", ) -> float: """Evaluate the model on MMLU using log-likelihood scoring over batched prefill passes. @@ -73,6 +74,8 @@ def megatron_mmlu( few_shots: The number of few-shot examples to use. fraction: The fraction of the test set to evaluate on. batch_size: Number of examples to process in one forward pass. + mmlu_dataset: HF dataset name or local MMLU dataset path passed to `datasets.load_dataset`. + Defaults to ``cais/mmlu``. """ print_rank_0( f"\nMMLU ({fraction * 100}%, {few_shots}-shot, Batch Size: {batch_size}) evaluation started...\n" @@ -104,8 +107,8 @@ def _generate_prompt(test_example, dev_examples, few_shots=0): # Load all subjects in two dataset calls instead of 2x num_subjects calls. # The "all" config includes a "subject" field for per-subject reporting. - test_dataset = load_dataset("cais/mmlu", "all", split="test") - dev_dataset = load_dataset("cais/mmlu", "all", split="dev") if few_shots > 0 else None + test_dataset = load_dataset(mmlu_dataset, "all", split="test") + dev_dataset = load_dataset(mmlu_dataset, "all", split="dev") if few_shots > 0 else None # Group dev examples by subject for few-shot prompt construction. dev_by_subject: dict = {} diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml index ff55a92e39f..6ae64fc1ff4 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml @@ -28,7 +28,7 @@ pipeline: calib_dataset: abisee/cnn_dailymail calib_size: 32 mmlu_dataset: cais/mmlu - mmlu_lower_bound: 0.68 + mmlu_lower_bound: 0.75 hf_local: /hf-local/ slurm_config: _factory_: "slurm_factory"