Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ Changelog
``ModelOptArgParser`` adds ``--config`` YAML support with CLI overrides and auto-generates ``ARGUMENTS.md`` from dataclass definitions.
Dataset blending (``configs/dataset/blend.yaml``) supports HuggingFace datasets, local JSON/JSONL/Parquet files, and weighted multi-source blends.
The legacy FSDP1 accelerate config is removed; ``llm_qat`` now documents FSDP2, DeepSpeed, and DDP backends.
- The PTQ example scripts ``examples/llm_ptq/hf_ptq.py``, ``examples/llm_ptq/multinode_ptq.py`` and ``examples/megatron_bridge/quantize.py`` now derive their ``--qformat`` / ``--kv_cache_qformat`` (``--quant_cfg`` / ``--kv_cache_quant`` for Megatron-Bridge) CLI vocabularies by discovering the YAML presets under ``modelopt_recipes/configs/ptq/presets/{model,kv}/`` rather than carrying hardcoded ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` tables. The two ``llm_ptq`` scripts share the discovery helper via ``examples/llm_ptq/example_utils.py``; the Megatron-Bridge script keeps its own copy. Presets are loaded eagerly into a plain dict at import. Adding a new preset YAML makes it available on the CLI of all three with no script change — note this means each script now accepts every preset under those directories, not just a previously curated subset. All previously-supported short names (``int8_sq``, ``nvfp4_awq``, ``fp8_pb_wo``, ``nvfp4_mse``, ``w4a8_awq``, ``nvfp4_local_hessian``, ``fp8_pc_pt``, ``int8_wo``, and the Megatron-Bridge ``fp8_blockwise``) keep working via a small deprecation alias table; new formats should be exposed as preset YAMLs (or, longer term, as full ``--recipe`` recipes).
- Add ``configs/ptq/presets/kv/fp8_cast.yaml`` and ``configs/ptq/presets/kv/nvfp4_cast.yaml``, promoting ``fp8_cast`` / ``nvfp4_cast`` to first-class KV presets composed from the existing ``kv_fp8_cast`` / ``kv_nvfp4_cast`` unit fragments. The previous runtime ``use_constant_amax`` post-edit in ``hf_ptq.py`` is removed; ``use_constant_amax: true`` now lives in the YAML and is therefore authoritative. **Custom (out-of-tree) recipes that target a cast KV format must set ``use_constant_amax: true`` themselves on the ``[kv]_bmm_quantizer`` config** — in-tree recipes already do via the ``kv_*_cast`` units.

**Bug Fixes**

Expand Down
96 changes: 95 additions & 1 deletion examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import shutil
import sys
import warnings
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Mapping
from pathlib import Path
from typing import Any

Expand All @@ -47,11 +47,105 @@
except ImportError:
snapshot_download = None

from modelopt.recipe import load_config
from modelopt.torch.opt.config_loader import BUILTIN_CONFIG_ROOT
from modelopt.torch.quantization.config import QuantizeConfig

logger = logging.getLogger(__name__)

SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]


# Preset directories under modelopt_recipes/ that back the --qformat and
# --kv_cache_qformat CLI vocabularies of the llm_ptq example scripts (hf_ptq.py,
# multinode_ptq.py). Each ``*.yaml`` file in these directories is automatically
# discovered and exposed as a valid CLI value via _load_preset_cfg_choices, so no
# code change is required when a YAML is added or removed. This is deliberate: every
# preset YAML is CLI-exposed, there is no separate allow-list — the directory
# listing is the policy.
#
# That said, prefer NOT to add new YAMLs to these preset directories either. The
# long-term direction is to retire --qformat / --kv_cache_qformat entirely in favour
# of --recipe, which accepts a full PTQ recipe (see modelopt_recipes/general/ptq/
# and modelopt/recipe/). New quantization configurations should be authored as
# recipes, not as preset entries.
_QFORMAT_PRESET_DIR = "configs/ptq/presets/model"
_KV_QFORMAT_PRESET_DIR = "configs/ptq/presets/kv"

# Backward-compat short names → canonical preset basename. These aliases predate the
# YAML-driven discovery below and remain accepted so existing scripts keep working.
#
# DO NOT add new entries here. New quantization formats must be exposed via their YAML
# basename under modelopt_recipes/configs/ptq/presets/model/ — the directory listing is
# the canonical CLI vocabulary. This table exists solely to keep pre-existing short
# names (and the scripts/docs that hardcode them) working through deprecation, and
# should only ever shrink.
_QFORMAT_ALIASES: dict[str, str] = {
"int8_sq": "int8_smoothquant",
"int8_wo": "int8_weight_only",
"w4a8_awq": "w4a8_awq_beta",
"nvfp4_awq": "nvfp4_awq_lite",
"nvfp4_mse": "nvfp4_w4a4_weight_mse_fp8_sweep",
"nvfp4_local_hessian": "nvfp4_w4a4_weight_local_hessian",
"fp8_pb_wo": "fp8_2d_blockwise_weight_only",
"fp8_pc_pt": "fp8_per_channel_per_token",
}

# Sentinel value for ``--kv_cache_qformat`` meaning "no KV cache quantization".
_KV_NONE = "none"


def _load_preset_cfg_choices(
subdir: str, aliases: Mapping[str, str] | None = None
) -> dict[str, dict[str, Any]]:
"""Build a ``{qformat_name: quant_cfg_dict}`` mapping from the preset YAMLs.

Every ``*.yaml`` under ``modelopt_recipes/<subdir>/`` is loaded and keyed by its
basename — the directory listing is the CLI vocabulary. ``aliases`` adds extra
short names pointing at canonical basenames; a stale alias raises here (at import)
rather than failing silently at lookup time.

Configs are loaded eagerly into a plain dict. Callers that mutate a returned
config must deepcopy it first (``build_quant_cfg`` and the other call sites
already do); this mirrors how the previous ``mtq.*_CFG`` module constants were
used. A lazy / copy-on-access variant can be reintroduced later if load time
ever becomes a concern.
"""
aliases = aliases or {}
basenames = sorted(
entry.name.rsplit(".", 1)[0]
for entry in BUILTIN_CONFIG_ROOT.joinpath(subdir).iterdir()
if entry.name.endswith((".yaml", ".yml"))
)
choices: dict[str, dict[str, Any]] = {
name: load_config(f"{subdir}/{name}", schema_type=QuantizeConfig).model_dump(
exclude_unset=True
)
for name in basenames
}
for alias, target in sorted(aliases.items()):
if target not in choices:
raise ValueError(
f"Alias {alias!r} points at preset {target!r} which is not present "
f"under modelopt_recipes/{subdir}/."
)
choices[alias] = choices[target]
return choices


QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = _load_preset_cfg_choices(
_QFORMAT_PRESET_DIR, _QFORMAT_ALIASES
)
Comment thread
shengliangxu marked this conversation as resolved.
Outdated
KV_QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = _load_preset_cfg_choices(_KV_QFORMAT_PRESET_DIR)

# Guard against a future ``none.yaml`` (or alias) colliding with the disable sentinel:
# argparse would silently allow both, but the runtime branch on ``!= _KV_NONE`` would
# become ambiguous and the user couldn't reach the real preset.
assert _KV_NONE not in KV_QUANT_CFG_CHOICES, (
f"_KV_NONE sentinel {_KV_NONE!r} collides with a KV preset; rename the preset."
)


def run_nemotron_vl_preview(
full_model,
tokenizer,
Expand Down
179 changes: 84 additions & 95 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4
from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static
from example_utils import (
_KV_NONE,
_QFORMAT_ALIASES,
KV_QUANT_CFG_CHOICES,
QUANT_CFG_CHOICES,
build_quant_cfg,
copy_custom_model_files,
create_vlm_calibration_loop,
Expand Down Expand Up @@ -86,56 +90,67 @@
RAND_SEED = 1234


def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"""Set use_constant_amax on KV cache quantizers.
def _kv_cfg_uses_constant_amax(kv_quant_cfg: list[dict[str, Any]]) -> bool:
"""Return True if this KV cfg pins ``use_constant_amax`` on the bmm quantizer.

Creates a new dict for the KV bmm quantizer config to avoid mutating shared references.
Cast-style KV presets (e.g. ``fp8_cast`` / ``nvfp4_cast``) set
``use_constant_amax: true`` on the ``*[kv]_bmm_quantizer`` entry; that flag
means there is no data-driven calibration to run, so callers should skip
the KV-only calibration pass. Detect the property from the YAML contents
rather than from the preset name so new cast-style presets work
automatically.
"""
for i, entry in enumerate(quant_cfg):
for entry in kv_quant_cfg:
if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
continue
cfg = entry.get("cfg") or {}
assert isinstance(cfg, dict)
quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
break


QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
"fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a16_nvfp4": mtq.W4A16_NVFP4_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
"nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG,
"nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
"nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
}

KV_QUANT_CFG_CHOICES = {
"none": "none",
"fp8_cast": "FP8_KV_CFG",
"fp8": "FP8_KV_CFG",
"fp8_affine": "FP8_AFFINE_KV_CFG",
"nvfp4_cast": "NVFP4_KV_CFG",
"nvfp4": "NVFP4_KV_CFG",
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
"nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
}

# Formats that use use_constant_amax (no calibration needed).
_KV_CAST_FORMATS = {"fp8_cast", "nvfp4_cast"}
return bool(cfg.get("use_constant_amax"))
return False


# Formats supported by mtq.auto_quantize unified-checkpoint export.
#
# This stays hardcoded — and intentionally not derived from the preset directory —
# because auto_quantize compatibility is a property of the export path (the unified
# HF checkpoint writer, TRT-LLM consumer constraints, layer-wise mixing rules), not
# of the YAML itself. A preset can exist and be valid for plain PTQ while not being
# safe to mix into an auto_quantize search. Update this set when adding/removing a
# format from auto_quantize support.
#
# NOTE: auto_quantize is being refactored/reimplemented; this table and the
# _canonical_qformat helper below are expected to be removed in the near future, so
# deliberately not invested in deriving them from the presets.
_AUTO_QUANTIZE_QFORMATS: frozenset[str] = frozenset(
{
"fp8",
"int8_smoothquant",
"int8_weight_only",
"int4_awq",
"nvfp4",
"nvfp4_awq_lite",
"nvfp4_w4a4_weight_mse_fp8_sweep",
"w4a8_awq_beta",
"fp8_2d_blockwise_weight_only",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_experts_only",
"nvfp4_omlp_only",
"nvfp4_w4a4_weight_local_hessian",
"mxfp8",
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
)


def _canonical_qformat(name: str) -> str:
"""Resolve a user-provided qformat token to its canonical preset basename.

Lets membership checks (e.g. against :data:`_AUTO_QUANTIZE_QFORMATS`) accept
either the short alias (``int8_sq``) or the canonical YAML basename
(``int8_smoothquant``). Unknown tokens pass through unchanged so the existing
error paths still fire.
"""
return _QFORMAT_ALIASES.get(name, name)


mto.enable_huggingface_checkpointing()

Expand Down Expand Up @@ -311,27 +326,11 @@ def auto_quantize(

qformat_list = args.qformat.split(",")
assert qformat_list, "No quantization formats provided"
# Check if all provided quantization formats are supported
# Check if all provided quantization formats are supported. Canonicalize first so
# callers may pass either the short alias (``int8_sq``) or the canonical YAML
# basename (``int8_smoothquant``).
assert all(
qformat
in [
"fp8",
"int8_sq",
"int8_wo",
"int4_awq",
"nvfp4",
"nvfp4_awq",
"nvfp4_mse",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_experts_only",
"nvfp4_omlp_only",
"nvfp4_local_hessian",
"mxfp8",
]
for qformat in qformat_list
_canonical_qformat(qformat) in _AUTO_QUANTIZE_QFORMATS for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

# When language_model is a base text model without lm_head (e.g. Gemma4TextModel),
Expand Down Expand Up @@ -417,21 +416,16 @@ def forward_step(model, batch):

calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
# We need to explicitly set up KV cache quantization after auto_quantize
enable_quant_kv_cache = args.kv_cache_qformat != "none"
enable_quant_kv_cache = args.kv_cache_qformat != _KV_NONE
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
if enable_quant_kv_cache:
kv_cache_quant_cfg = copy.deepcopy(
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
)
kv_cache_quant_cfg = copy.deepcopy(KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"])
kv_cache_quant_cfg = [
e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*"
] # keep other quantizers from auto_quantize

if args.kv_cache_qformat in _KV_CAST_FORMATS:
_set_kv_cache_constant_amax(kv_cache_quant_cfg)

mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg)
if args.kv_cache_qformat not in _KV_CAST_FORMATS:
if not _kv_cfg_uses_constant_amax(kv_cache_quant_cfg):
# Calibrate only the KV cache quantizers; disable all others.
with mtq.set_quantizer_by_cfg_context(
language_model,
Expand All @@ -455,21 +449,14 @@ def load_model(args: argparse.Namespace):
)
else:
assert args.qformat in QUANT_CFG_CHOICES, (
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
f"Quantization format is not supported for low memory mode. Supported formats: {list(QUANT_CFG_CHOICES)}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
if args.kv_cache_qformat != "none":
if args.kv_cache_qformat != _KV_NONE:
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"],
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
# Mirror the use_constant_amax logic from quantize_main so that init_quantized_weights
# builds the KV quantizers with use_constant_amax already set. In calibration_only mode
# mtq.calibrate() does not re-apply quant_cfg, so this must happen before
# init_quantized_weights runs.
if args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

# Do not use real quant GEMM so the calibration can be more accurate.
with init_quantized_weights(
Expand Down Expand Up @@ -1103,7 +1090,7 @@ def _is_layerwise(obj):
)

assert args.qformat in QUANT_CFG_CHOICES, (
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}"
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES)}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]

Expand All @@ -1113,14 +1100,14 @@ def _is_layerwise(obj):
args.moe_calib_experts_ratio,
)

enable_quant_kv_cache = args.kv_cache_qformat != "none"
enable_quant_kv_cache = args.kv_cache_qformat != _KV_NONE
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"],
)

# Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92).
Expand All @@ -1135,14 +1122,6 @@ def _is_layerwise(obj):
quant_cfg["quant_cfg"].append({"quantizer_name": pattern, "enable": False})
print(f"Excluding MTP layer from quantization: {pattern}")

# Use constant amax for KV quantizers when a cast format is selected.
# Recipes are authoritative for KV cache config (including use_constant_amax),
# so skip this post-hoc override when --recipe is used; rely on the YAML instead
# (see modelopt_recipes/general/ptq/*_cast_kv.yaml).
if args.recipe is None and args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

if needs_checkpoint_path_update(quant_cfg):
quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
print(
Expand Down Expand Up @@ -1293,7 +1272,7 @@ def parse_args() -> argparse.Namespace:
"--kv_cache_qformat",
required=False,
default="fp8_cast",
choices=KV_QUANT_CFG_CHOICES.keys(),
choices=[_KV_NONE, *KV_QUANT_CFG_CHOICES],
Comment thread
shengliangxu marked this conversation as resolved.
Outdated
help=(
"Specify KV cache quantization format. Default: fp8_cast. "
"Formats ending in '_cast' (fp8_cast, nvfp4_cast) set the amax to FP8 range "
Expand Down Expand Up @@ -1475,6 +1454,16 @@ def parse_args() -> argparse.Namespace:
if args.specdec_offline_dataset is not None and args.low_memory_mode:
parser.error("--specdec_offline_dataset is not compatible with --low_memory_mode.")

# The low-memory loader pre-instruments quantizers from --qformat/--kv_cache_qformat
# via init_quantized_weights(), so it cannot honor a --recipe (which is authoritative
# for the quant layout in quantize_main). Reject the combination rather than silently
# instrumenting a layout that diverges from the recipe.
if args.low_memory_mode and args.recipe is not None:
parser.error(
"--low_memory_mode does not yet support --recipe; the low-memory loader still "
"initializes quantizers from --qformat/--kv_cache_qformat."
)

return args


Expand Down
Loading
Loading