Skip to content
51 changes: 43 additions & 8 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def auto_quantize(
auto_quantize_method="gradient",
auto_quantize_score_size=128,
auto_quantize_checkpoint=None,
full_model: torch.nn.Module | None = None,
):
"""Auto search quantization of multiple formats."""

Expand Down Expand Up @@ -330,19 +331,49 @@ def auto_quantize(
for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

def loss_func(output, data):
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
# which contains the loss attribute.
return output.loss
# When language_model is a base text model without lm_head (e.g. Gemma4TextModel),
# use full_model's lm_head to compute logits/loss from hidden states.
is_base_model = (
full_model is not None
and language_model is not full_model
and not hasattr(language_model, "lm_head")
and hasattr(full_model, "lm_head")
)

if is_base_model:
assert full_model is not None
lm_head = full_model.lm_head

def loss_func(output, data):
logits = lm_head(output.last_hidden_state)
labels = data["labels"]
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)

else:

def loss_func(output, data):
return output.loss

if auto_quantize_method == "gradient":
# For gradient-based method, return full output with loss

def forward_step(model, batch):
return model(**batch)
inputs = {k: v for k, v in batch.items() if k != "labels"} if is_base_model else batch
return model(**inputs)

elif auto_quantize_method == "kl_div":
# For KL divergence method, return only logits

def forward_step(model, batch):
return model(**batch).logits
inputs = {k: v for k, v in batch.items() if k != "labels"} if is_base_model else batch
output = model(**inputs)
if is_base_model:
assert full_model is not None
return full_model.lm_head(output.last_hidden_state)
return output.logits

else:
raise ValueError(
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
Expand Down Expand Up @@ -1022,6 +1053,10 @@ def quantize_main(
args,
language_model,
calib_dataloader,
auto_quantize_method=args.auto_quantize_method,
auto_quantize_score_size=args.auto_quantize_score_size,
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
full_model=full_model,
)

else:
Expand Down
14 changes: 13 additions & 1 deletion modelopt/torch/export/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def get_experts_list(
linear_names = ["gate_proj", "down_proj", "up_proj"]
elif "nemotronhforcausallm" in model_type:
linear_names = ["up_proj", "down_proj"]
elif "gemma4" in model_type:
linear_names = ["gate_proj", "down_proj", "up_proj"]
else:
raise NotImplementedError(f" {model_type} not supported")

Expand Down Expand Up @@ -315,7 +317,14 @@ def is_moe(module: nn.Module) -> bool:
if name.endswith("sparsemoeblock") or "moelayer" in name:
return True
# Explicit matches for non-standard naming
return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn", "nemotronhmoe"])
if any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn", "nemotronhmoe"]):
return True
# Structural detection: modules with router + experts (e.g. Gemma4TextDecoderLayer)
return (
hasattr(module, "router")
and hasattr(module, "experts")
and isinstance(module.experts, nn.Module)
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def is_quantlinear(module: nn.Module) -> bool:
Expand Down Expand Up @@ -1007,6 +1016,9 @@ def module_match_name_list(module, name_list):
elif module_match_name_list(module, ["NemotronHMOE"]):
# NemotronHMOE experts (NemotronHMLP) use up_proj and down_proj only (no gate).
return ["up_proj", "down_proj"]
elif module_match_name_list(module, ["Gemma4TextDecoderLayer"]):
# Gemma4 MoE experts are unfused into per-expert nn.Linear layers
return ["gate_proj", "down_proj", "up_proj"]
else:
# assuming w1, w2, w3 by default
return ["w1", "w2", "w3"]
Expand Down
6 changes: 4 additions & 2 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,10 @@ def _nvfp4_selective_quant_cfg(
NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg(
["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_cfg_bs32, weight_only=True
)
NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp.experts*", "*block_sparse_moe*"])
NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*"])
NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(
["*mlp.experts*", "*block_sparse_moe*", "*.experts.*"]
)
NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*", "*.experts.*"])
NVFP4_OMLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*o_proj*", "*mlp*", "*block_sparse_moe*"])

# DO NOT ADD NEW CONFIGS HERE. If you want to add a new general recipe, add it to
Expand Down
16 changes: 16 additions & 0 deletions modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,22 @@ quantize:
type: dynamic
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*.experts.*weight_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*.experts.*input_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*[kv]_bmm_quantizer'
enable: true
cfg:
Expand Down
115 changes: 115 additions & 0 deletions tests/unit/torch/export/test_layer_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for modelopt.torch.export.layer_utils — MoE detection and expert naming."""

import pytest
import torch.nn as nn

from modelopt.torch.export.layer_utils import get_expert_linear_names, is_moe

# ---------------------------------------------------------------------------
# is_moe tests
# ---------------------------------------------------------------------------


class _FakeSparseMoeBlock(nn.Module):
    """Fake module whose lowercased class name ends in 'sparsemoeblock' (suffix-based detection)."""


class _FakeMoeLayer(nn.Module):
    """Fake module whose lowercased class name contains 'moelayer' (substring-based detection)."""


class _FakeArcticMoe(nn.Module):
    """Fake module whose lowercased class name contains 'arcticmoe' (explicit key match)."""


class _StructuralMoeModule(nn.Module):
    """MoE-shaped module: exposes both a `router` and an `experts` submodule."""

    def __init__(self):
        super().__init__()
        # Router projects an 8-dim hidden state to 4 expert scores.
        self.router = nn.Linear(8, 4)
        # Four square expert projections held in a ModuleList (itself an nn.Module).
        expert_layers = [nn.Linear(8, 8) for _ in range(4)]
        self.experts = nn.ModuleList(expert_layers)


class _NotMoeModule(nn.Module):
    """Ordinary single-layer module with no router/experts attributes."""

    def __init__(self):
        super().__init__()
        layer = nn.Linear(8, 8)
        self.fc = layer


class _PartialStructuralModule(nn.Module):
    """Has a `router` but no `experts` — must not satisfy the structural MoE check."""

    def __init__(self):
        super().__init__()
        router_layer = nn.Linear(8, 4)
        self.router = router_layer


@pytest.mark.parametrize(
    "module_cls",
    [_FakeSparseMoeBlock, _FakeMoeLayer, _FakeArcticMoe],
)
def test_is_moe_name_based(module_cls):
    """Each fake class should be flagged as MoE purely from its class name."""
    instance = module_cls()
    assert is_moe(instance)


def test_is_moe_structural():
    """Router + experts attributes alone should trigger structural MoE detection."""
    module = _StructuralMoeModule()
    assert is_moe(module)


def test_is_moe_negative():
    """A plain module must not be classified as MoE."""
    plain = _NotMoeModule()
    assert is_moe(plain) is False


def test_is_moe_partial_structural():
    """A router without experts does not satisfy the structural check."""
    partial = _PartialStructuralModule()
    assert is_moe(partial) is False


# ---------------------------------------------------------------------------
# get_expert_linear_names tests
# ---------------------------------------------------------------------------


class _FakeGemma4TextDecoderLayer(nn.Module):
    """Name matches 'Gemma4TextDecoderLayer' for expert-linear-name lookup."""


class _FakeMixtralSparseMoeBlock(nn.Module):
    """Mixtral-style name that should fall through to the default expert naming."""


class _FakeNemotronHMOE(nn.Module):
    """Name matches 'NemotronHMOE' for expert-linear-name lookup."""


def test_get_expert_linear_names_gemma4():
    """Gemma4 decoder layers expose unfused gate/down/up expert projections."""
    expected = ["gate_proj", "down_proj", "up_proj"]
    names = get_expert_linear_names(_FakeGemma4TextDecoderLayer())
    assert names == expected


def test_get_expert_linear_names_mixtral():
    """Unmatched MoE module names fall back to the default w1/w2/w3 scheme."""
    layer = _FakeMixtralSparseMoeBlock()
    assert get_expert_linear_names(layer) == ["w1", "w2", "w3"]


def test_get_expert_linear_names_nemotron():
    """NemotronH experts use up_proj/down_proj only (no gate projection)."""
    layer = _FakeNemotronHMOE()
    assert get_expert_linear_names(layer) == ["up_proj", "down_proj"]
Loading