diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index d660c1de4c..5d06c50c47 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -292,6 +292,7 @@ def auto_quantize(
     auto_quantize_method="gradient",
     auto_quantize_score_size=128,
     auto_quantize_checkpoint=None,
+    full_model: torch.nn.Module | None = None,
 ):
     """Auto search quantization of multiple formats."""
@@ -330,19 +331,49 @@ def auto_quantize(
         for qformat in qformat_list
     ), "One or more quantization formats provided are not supported for unified checkpoint export"
 
-    def loss_func(output, data):
-        # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
-        # which contains the loss attribute.
-        return output.loss
+    # When language_model is a base text model without lm_head (e.g. Gemma4TextModel),
+    # use full_model's lm_head to compute logits/loss from hidden states.
+    is_base_model = (
+        full_model is not None
+        and language_model is not full_model
+        and not hasattr(language_model, "lm_head")
+        and hasattr(full_model, "lm_head")
+    )
+
+    if is_base_model:
+        assert full_model is not None
+        lm_head = full_model.lm_head
+
+        def loss_func(output, data):
+            logits = lm_head(output.last_hidden_state)
+            labels = data["labels"]
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            return torch.nn.functional.cross_entropy(
+                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+            )
+
+    else:
+
+        def loss_func(output, data):
+            return output.loss
 
     if auto_quantize_method == "gradient":
-        # For gradient-based method, return full output with loss
+
         def forward_step(model, batch):
-            return model(**batch)
+            inputs = {k: v for k, v in batch.items() if k != "labels"} if is_base_model else batch
+            return model(**inputs)
+
     elif auto_quantize_method == "kl_div":
-        # For KL divergence method, return only logits
+
         def forward_step(model, batch):
-            return model(**batch).logits
+            inputs = {k: v for k, v in batch.items() if k != "labels"} if is_base_model else batch
+            output = model(**inputs)
+            if is_base_model:
+                assert full_model is not None
+                return full_model.lm_head(output.last_hidden_state)
+            return output.logits
+
     else:
         raise ValueError(
             f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
@@ -1022,6 +1053,10 @@ def quantize_main(
             args,
             language_model,
             calib_dataloader,
+            auto_quantize_method=args.auto_quantize_method,
+            auto_quantize_score_size=args.auto_quantize_score_size,
+            auto_quantize_checkpoint=args.auto_quantize_checkpoint,
+            full_model=full_model,
         )
     else:
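For context on the new base-model loss path: the hunk above computes a standard
shifted next-token cross-entropy from hidden states when the text backbone has no
lm_head of its own. A minimal, self-contained sketch of that computation follows;
the toy sizes and the FakeOutput wrapper are illustrative only, not part of the change.

import torch
import torch.nn as nn


class FakeOutput:
    """Hypothetical stand-in for a transformers base-model output object."""

    def __init__(self, last_hidden_state):
        self.last_hidden_state = last_hidden_state


lm_head = nn.Linear(16, 32, bias=False)  # hidden_size=16, vocab_size=32 (toy values)
output = FakeOutput(torch.randn(2, 8, 16))  # batch=2, seq_len=8
data = {"labels": torch.randint(0, 32, (2, 8))}

logits = lm_head(output.last_hidden_state)
# Shift so that the logits at position t are scored against the token at t+1,
# exactly as the loss_func in the diff above does.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = data["labels"][..., 1:].contiguous()
loss = nn.functional.cross_entropy(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
print(loss.item())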
diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py
index e8ee5afd45..b91001ffae 100755
--- a/modelopt/torch/export/layer_utils.py
+++ b/modelopt/torch/export/layer_utils.py
@@ -108,6 +108,8 @@ def get_experts_list(
         linear_names = ["gate_proj", "down_proj", "up_proj"]
     elif "nemotronhforcausallm" in model_type:
         linear_names = ["up_proj", "down_proj"]
+    elif "gemma4" in model_type:
+        linear_names = ["gate_proj", "down_proj", "up_proj"]
     else:
         raise NotImplementedError(f" {model_type} not supported")
@@ -315,7 +317,14 @@ def is_moe(module: nn.Module) -> bool:
     if name.endswith("sparsemoeblock") or "moelayer" in name:
         return True
     # Explicit matches for non-standard naming
-    return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn", "nemotronhmoe"])
+    if any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn", "nemotronhmoe"]):
+        return True
+    # Structural detection: modules with router + experts (e.g. Gemma4TextDecoderLayer)
+    return (
+        hasattr(module, "router")
+        and hasattr(module, "experts")
+        and isinstance(module.experts, nn.Module)
+    )
 
 
 def is_quantlinear(module: nn.Module) -> bool:
@@ -1007,6 +1016,9 @@ def module_match_name_list(module, name_list):
     elif module_match_name_list(module, ["NemotronHMOE"]):
         # NemotronHMOE experts (NemotronHMLP) use up_proj and down_proj only (no gate).
         return ["up_proj", "down_proj"]
+    elif module_match_name_list(module, ["Gemma4TextDecoderLayer"]):
+        # Gemma4 MoE experts are unfused into per-expert nn.Linear layers
+        return ["gate_proj", "down_proj", "up_proj"]
     else:
         # assuming w1, w2, w3 by default
         return ["w1", "w2", "w3"]
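The structural fallback added to is_moe can be sanity-checked in isolation: any
module exposing both a router attribute and an experts nn.Module now counts as MoE.
TinyMoeBlock below is a made-up stand-in, not a real HF class.

import torch.nn as nn


class TinyMoeBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.router = nn.Linear(8, 4)
        self.experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])


m = TinyMoeBlock()
# Same predicate as the new return expression in the diff above.
print(
    hasattr(m, "router") and hasattr(m, "experts") and isinstance(m.experts, nn.Module)
)  # True — nn.ModuleList is itself an nn.Module subclass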
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 34e7f692ca..6939c12558 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -799,8 +799,10 @@ def _nvfp4_selective_quant_cfg(
 NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg(
     ["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_cfg_bs32, weight_only=True
 )
-NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp.experts*", "*block_sparse_moe*"])
-NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*"])
+NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(
+    ["*mlp.experts*", "*block_sparse_moe*", "*.experts.*"]
+)
+NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*", "*.experts.*"])
 NVFP4_OMLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*o_proj*", "*mlp*", "*block_sparse_moe*"])
 # DO NOT ADD NEW CONFIGS HERE. If you want to add a new general recipe, add it to
diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml
index 0222274af0..d4f9cce18c 100644
--- a/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml
+++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml
@@ -53,6 +53,22 @@ quantize:
         type: dynamic
         scale_bits: e4m3
         num_bits: e2m1
+    - quantizer_name: '*.experts.*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: dynamic
+        scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*.experts.*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: dynamic
+        scale_bits: e4m3
+        num_bits: e2m1
     - quantizer_name: '*[kv]_bmm_quantizer'
       enable: true
      cfg:
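Why the extra '*.experts.*' pattern matters: Gemma4-style per-expert Linear layers
can live outside a module literally named `mlp`, so the old '*mlp.experts*' glob
misses them. The sketch below uses stdlib fnmatch as a stand-in for ModelOpt's
wildcard matching (an assumption about the exact semantics), and the module name
is illustrative, not taken from a real checkpoint.

from fnmatch import fnmatch

name = "model.layers.0.experts.3.gate_proj.weight_quantizer"
print(fnmatch(name, "*mlp.experts*"))  # False — no 'mlp.experts' substring
print(fnmatch(name, "*.experts.*"))    # True  — matched by the new pattern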
diff --git a/tests/unit/torch/export/test_layer_utils.py b/tests/unit/torch/export/test_layer_utils.py
new file mode 100644
index 0000000000..93742c8fe7
--- /dev/null
+++ b/tests/unit/torch/export/test_layer_utils.py
@@ -0,0 +1,115 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for modelopt.torch.export.layer_utils — MoE detection and expert naming."""
+
+import pytest
+import torch.nn as nn
+
+from modelopt.torch.export.layer_utils import get_expert_linear_names, is_moe
+
+# ---------------------------------------------------------------------------
+# is_moe tests
+# ---------------------------------------------------------------------------
+
+
+class _FakeSparseMoeBlock(nn.Module):
+    """Name ends with 'sparsemoeblock' — detected by naming convention."""
+
+
+class _FakeMoeLayer(nn.Module):
+    """Name contains 'moelayer' — detected by naming convention."""
+
+
+class _FakeArcticMoe(nn.Module):
+    """Name contains 'arcticmoe' — detected by explicit match."""
+
+
+class _StructuralMoeModule(nn.Module):
+    """Has router + experts attributes — detected by structural check."""
+
+    def __init__(self):
+        super().__init__()
+        self.router = nn.Linear(8, 4)
+        self.experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
+
+
+class _NotMoeModule(nn.Module):
+    """Plain module — should NOT be classified as MoE."""
+
+    def __init__(self):
+        super().__init__()
+        self.fc = nn.Linear(8, 8)
+
+
+class _PartialStructuralModule(nn.Module):
+    """Has router but no experts — should NOT be classified as MoE."""
+
+    def __init__(self):
+        super().__init__()
+        self.router = nn.Linear(8, 4)
+
+
+@pytest.mark.parametrize(
+    "module_cls",
+    [_FakeSparseMoeBlock, _FakeMoeLayer, _FakeArcticMoe],
+)
+def test_is_moe_name_based(module_cls):
+    assert is_moe(module_cls())
+
+
+def test_is_moe_structural():
+    assert is_moe(_StructuralMoeModule())
+
+
+def test_is_moe_negative():
+    assert not is_moe(_NotMoeModule())
+
+
+def test_is_moe_partial_structural():
+    assert not is_moe(_PartialStructuralModule())
+
+
+# ---------------------------------------------------------------------------
+# get_expert_linear_names tests
+# ---------------------------------------------------------------------------
+
+
+class _FakeGemma4TextDecoderLayer(nn.Module):
+    pass
+
+
+class _FakeMixtralSparseMoeBlock(nn.Module):
+    pass
+
+
+class _FakeNemotronHMOE(nn.Module):
+    pass
+
+
+def test_get_expert_linear_names_gemma4():
+    assert get_expert_linear_names(_FakeGemma4TextDecoderLayer()) == [
+        "gate_proj",
+        "down_proj",
+        "up_proj",
+    ]
+
+
+def test_get_expert_linear_names_mixtral():
+    assert get_expert_linear_names(_FakeMixtralSparseMoeBlock()) == ["w1", "w2", "w3"]
+
+
+def test_get_expert_linear_names_nemotron():
+    assert get_expert_linear_names(_FakeNemotronHMOE()) == ["up_proj", "down_proj"]
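The new tests build only tiny nn.Module stubs, so they should run without a GPU or
any model checkpoints, e.g.:

pytest tests/unit/torch/export/test_layer_utils.py -v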