Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ Changelog
- DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.

0.44 (2026-05-18)
**Bug Fixes**

- Fix Megatron-Core HF importer to load the fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. The importer now prefers the per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (falling back to ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, the fused LayerNorm weights were left at random init and post-prune MMLU sat at chance level.

0.44 (2026-05-14)
^^^^^^^^^^^^^^^^^

**New Features**
Expand Down
9 changes: 9 additions & 0 deletions modelopt/torch/export/plugins/mcore_deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
"linear_kv_up_proj": NameRemapping("model.layers.{}.self_attn.kv_b_proj."),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
# Fused TE spec (mirrors the import side). MLA has no linear_qkv so
# fused_input_layernorm is inert today; fused_pre_mlp_layernorm reaches dense layers.
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
# MLP for dense layers
"linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
Expand Down Expand Up @@ -88,6 +92,11 @@
"output_layer": NameRemapping("lm_head.", COL_TP),
# Per-layer
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
# Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale.
# MLA has no linear_qkv so fused_input_layernorm is inert for DeepSeek today; included
# for parity in case a future spec fuses the layernorm into a Q/KV projection.
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
"linear_q_proj": NameRemapping("model.layers.{}.self_attn.q_proj.", COL_TP),
"linear_q_down_proj": NameRemapping("model.layers.{}.self_attn.q_a_proj.", REPLICATE),
"linear_q_layernorm": NameRemapping("model.layers.{}.self_attn.q_a_layernorm.", REPLICATE),
Expand Down
6 changes: 6 additions & 0 deletions modelopt/torch/export/plugins/mcore_gptoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
gptoss_causal_lm_export: dict[str, CustomModuleMapping | bool] = {
"word_embeddings": NameRemapping("model.embed_tokens."),
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
# MoE-only on MLP side, so fused_pre_mlp_layernorm path is unreachable.
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
"softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks"),
Expand All @@ -52,6 +54,10 @@
gptoss_causal_lm_import: dict[str, CustomModuleMapping | bool] = {
"word_embeddings": NameRemapping("model.embed_tokens.", COL_TP),
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
# Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale.
# gpt-oss is MoE-only on the MLP side (no layer.mlp.linear_fc1), so the importer's
# fused_pre_mlp_layernorm path is unreachable; only fused_input_layernorm is wired.
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
"softmax_offset": NameRemapping("model.layers.{}.self_attn.sinks", COL_TP),
Expand Down
11 changes: 11 additions & 0 deletions modelopt/torch/export/plugins/mcore_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@
llama_causal_lm_export: dict[str, CustomModuleMapping] = {
"word_embeddings": NameRemapping("model.embed_tokens."),
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
# KV cache quant export
"core_attention": SelfAttentionScaling("model.layers.{}.self_attn."),
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
"linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
"final_layernorm": NameRemapping("model.norm."),
Expand All @@ -51,6 +53,8 @@
llama4_causal_lm_export: dict[str, CustomModuleMapping | bool] = {
"word_embeddings": NameRemapping("language_model.model.embed_tokens."),
"input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm."),
# MoE-only on MLP side, so fused_pre_mlp_layernorm path is unreachable.
"fused_input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.weight"),
# self_attn
"linear_qkv": QKVSlicing("language_model.model.layers.{}.self_attn."),
"linear_proj": NameRemapping("language_model.model.layers.{}.self_attn.o_proj."),
Expand Down Expand Up @@ -150,9 +154,12 @@
llama_causal_lm_import: dict[str, CustomModuleMapping] = {
"word_embeddings": NameRemapping("model.embed_tokens.", COL_TP),
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
# Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale.
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE),
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
"linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP),
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP),
"final_layernorm": NameRemapping("model.norm.", REPLICATE),
Expand All @@ -162,6 +169,10 @@
llama4_causal_lm_import: dict[str, CustomModuleMapping | bool] = {
"word_embeddings": NameRemapping("language_model.model.embed_tokens.", COL_TP),
"input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.", REPLICATE),
# Fused TE spec (TELayerNormColumnParallelLinear) — see mcore_qwen.py for rationale.
# Llama4 is MoE-only on the MLP side (no layer.mlp.linear_fc1), so the importer's
# fused_pre_mlp_layernorm path is unreachable; only fused_input_layernorm is wired.
"fused_input_layernorm": NameRemapping("language_model.model.layers.{}.input_layernorm.weight"),
"linear_qkv": QKVMerging("language_model.model.layers.{}.self_attn.", COL_TP),
"linear_proj": NameRemapping("language_model.model.layers.{}.self_attn.o_proj.", ROW_TP),
"pre_mlp_layernorm": NameRemapping(
Expand Down
11 changes: 11 additions & 0 deletions modelopt/torch/export/plugins/mcore_qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@
"output_layer": NameRemapping("lm_head.", COL_TP),
# Attention
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
# Fused TE spec (TELayerNormColumnParallelLinear): the LayerNorm weight lives on
# linear_qkv.layer_norm_weight and is loaded directly from the HF norm tensor. The rule
# spells out the full HF key, including `.weight`, because the target is a Parameter
# rather than a sub-module, so no `.weight` suffix is appended.
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
"q_layernorm": NameRemapping("model.layers.{}.self_attn.q_norm.", REPLICATE),
"k_layernorm": NameRemapping("model.layers.{}.self_attn.k_norm.", REPLICATE),
# MLP
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE),
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
"linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP),
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP),
# MoE
Expand All @@ -56,12 +61,14 @@
"output_layer": NameRemapping("lm_head."),
# Attention
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
"q_layernorm": NameRemapping("model.layers.{}.self_attn.q_norm."),
"k_layernorm": NameRemapping("model.layers.{}.self_attn.k_norm."),
# MLP
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
"linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
# MoE
Expand All @@ -76,10 +83,12 @@
"output_layer": NameRemapping("lm_head.", COL_TP),
# Attention
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm.", REPLICATE),
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"linear_qkv": QKVMerging("model.layers.{}.self_attn.", COL_TP),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj.", ROW_TP),
# MLP
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.", REPLICATE),
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
"linear_fc1": GatedMLPMerging("model.layers.{}.mlp.", COL_TP),
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj.", ROW_TP),
}
Expand All @@ -90,10 +99,12 @@
"output_layer": NameRemapping("lm_head."),
# Attention
"input_layernorm": NameRemapping("model.layers.{}.input_layernorm."),
"fused_input_layernorm": NameRemapping("model.layers.{}.input_layernorm.weight"),
"linear_qkv": QKVSlicing("model.layers.{}.self_attn."),
"linear_proj": NameRemapping("model.layers.{}.self_attn.o_proj."),
# MLP
"pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm."),
"fused_pre_mlp_layernorm": NameRemapping("model.layers.{}.post_attention_layernorm.weight"),
"linear_fc1": GatedMLPSlicing("model.layers.{}.mlp."),
"linear_fc2": NameRemapping("model.layers.{}.mlp.down_proj."),
}
74 changes: 61 additions & 13 deletions modelopt/torch/export/plugins/megatron_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,9 @@ def _gated_mlp_merging(
else:
prefix = prefix.replace("model", "mtp")

weight = module.state_dict().get("weight", None)
weight_scale = module.state_dict().get("weight_quantizer._scale", None)
module_state_dict = module.state_dict()
weight = module_state_dict.get("weight", None)
weight_scale = module_state_dict.get("weight_quantizer._scale", None)

state_dict = {}

Expand Down Expand Up @@ -273,6 +274,16 @@ def _gated_mlp_merging(
else:
state_dict["weight"] = tensor.to(self.dtype).to(device=weight.device)

# Preserve the fused LayerNorm weight + TE _extra_state already on the module so
# the strict load_state_dict below doesn't fail for TELayerNormColumnParallelLinear
# (fused under --export-default-te-spec). The actual HF norm tensor is loaded
# separately via the `fused_pre_mlp_layernorm` rule.
layer_norm_weight = module_state_dict.get("layer_norm_weight", None)
if layer_norm_weight is not None:
state_dict["layer_norm_weight"] = layer_norm_weight
if "_extra_state" in module_state_dict:
state_dict["_extra_state"] = module_state_dict["_extra_state"]

Comment thread
coderabbitai[bot] marked this conversation as resolved.
module.load_state_dict(state_dict)

def _grouped_mlp_merging(
Expand Down Expand Up @@ -433,7 +444,13 @@ def _qkv_merging(
layer_norm_weight = module_state_dict.get("layer_norm_weight", None)
if layer_norm_weight is not None:
state_dict["layer_norm_weight"] = layer_norm_weight
state_dict["_extra_state"] = None # for TE modules require _extra_state key
# Preserve the TE metadata struct (FP8 amax history, recipe version, etc.).
# `load_state_dict(..., strict=True)` requires the key when the module has it, but
# blanking it could zero out per-module FP8 bookkeeping on TE versions that populate
# it. Forward it only when the source state dict actually contains it, so TE variants
# without `_extra_state` don't receive an unexpected `_extra_state=None` entry.
if "_extra_state" in module_state_dict:
state_dict["_extra_state"] = module_state_dict["_extra_state"]

module.load_state_dict(state_dict)

Expand Down Expand Up @@ -599,14 +616,32 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
)

# TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear).
# Load the fused layer_norm_weight from the HF norm path.
# Prefer the per-context key (`fused_input_layernorm`); fall back to the legacy
# single-key `fused_norm` for Nemotron-H style (one norm shared across slots).
# Missing both is a plugin misconfig — raise rather than silently random-init.
if (
isinstance(layer.input_layernorm, IdentityOp)
and hasattr(attention, "linear_qkv")
and hasattr(attention.linear_qkv, "layer_norm_weight")
and "fused_norm" in self.rules
):
self.rules["fused_norm"](
fused_key = (
"fused_input_layernorm"
if "fused_input_layernorm" in self.rules
else "fused_norm"
)
if fused_key not in self.rules:
# Branch only fires when model uses fused TELayerNormColumnParallelLinear,
# so missing rule is unambiguously a plugin misconfiguration; raise so it
# doesn't silently ship a chance-accuracy checkpoint.
raise KeyError(
f"{self.arch} uses fused TELayerNormColumnParallelLinear for "
"attention but neither `fused_input_layernorm` nor legacy "
"`fused_norm` is in its import mapping; `linear_qkv.layer_norm_weight` "
"would be left at random init. Add "
'`fused_input_layernorm: NameRemapping("...input_layernorm.weight")` '
f"to the {self.arch} import mapping."
)
self.rules[fused_key](
attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
)

Expand Down Expand Up @@ -707,14 +742,27 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp)

# TE spec: pre_mlp_layernorm is fused into linear_fc1
# (TELayerNormColumnParallelLinear).
# Load the fused layer_norm_weight from the HF norm path.
if (
isinstance(layer.pre_mlp_layernorm, IdentityOp)
and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
and "fused_norm" in self.rules
# (TELayerNormColumnParallelLinear). See input_layernorm path above for the
# rule-key fallback rationale.
if isinstance(layer.pre_mlp_layernorm, IdentityOp) and hasattr(
layer.mlp.linear_fc1, "layer_norm_weight"
):
self.rules["fused_norm"](
fused_key = (
"fused_pre_mlp_layernorm"
if "fused_pre_mlp_layernorm" in self.rules
else "fused_norm"
)
if fused_key not in self.rules:
raise KeyError(
f"{self.arch} uses fused TELayerNormColumnParallelLinear for "
"MLP but neither `fused_pre_mlp_layernorm` nor legacy "
"`fused_norm` is in its import mapping; "
"`linear_fc1.layer_norm_weight` would be left at random init. "
"Add `fused_pre_mlp_layernorm: NameRemapping("
'"...post_attention_layernorm.weight")` '
f"to the {self.arch} import mapping."
)
self.rules[fused_key](
layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
)

Expand Down
Loading
Loading