Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ Changelog
``ModelOptArgParser`` adds ``--config`` YAML support with CLI overrides and auto-generates ``ARGUMENTS.md`` from dataclass definitions.
Dataset blending (``configs/dataset/blend.yaml``) supports HuggingFace datasets, local JSON/JSONL/Parquet files, and weighted multi-source blends.
The legacy FSDP1 accelerate config is removed; ``llm_qat`` now documents FSDP2, DeepSpeed, and DDP backends.
- The PTQ example scripts ``examples/llm_ptq/hf_ptq.py``, ``examples/llm_ptq/multinode_ptq.py`` and ``examples/megatron_bridge/quantize.py`` now derive their ``--qformat`` / ``--kv_cache_qformat`` (``--quant_cfg`` / ``--kv_cache_quant`` for Megatron-Bridge) CLI vocabularies by discovering the YAML presets under ``modelopt_recipes/configs/ptq/presets/{model,kv}/`` rather than carrying hardcoded ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` tables. The discovery helper, alias table and ready-built ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` mappings now live in ``modelopt.recipe.presets`` and are shared by all three scripts. Presets are loaded eagerly into a plain dict at import. Adding a new preset YAML makes it available on the CLI of all three with no script change — note this means each script now accepts every preset under those directories, not just a previously curated subset. All previously-supported short names (``int8_sq``, ``nvfp4_awq``, ``fp8_pb_wo``, ``nvfp4_mse``, ``w4a8_awq``, ``nvfp4_local_hessian``, ``fp8_pc_pt``, ``int8_wo``) keep working via a small deprecation alias table; new formats should be exposed as preset YAMLs (or, longer term, as full ``--recipe`` recipes).
- Add ``configs/ptq/presets/kv/fp8_cast.yaml`` and ``configs/ptq/presets/kv/nvfp4_cast.yaml``, promoting ``fp8_cast`` / ``nvfp4_cast`` to first-class KV presets composed from the existing ``kv_fp8_cast`` / ``kv_nvfp4_cast`` unit fragments. The previous runtime ``use_constant_amax`` post-edit in ``hf_ptq.py`` is removed; ``use_constant_amax: true`` now lives in the YAML and is therefore authoritative. **Custom (out-of-tree) recipes that target a cast KV format must set ``use_constant_amax: true`` themselves on the ``[kv]_bmm_quantizer`` config** — in-tree recipes already do via the ``kv_*_cast`` units.

**Bug Fixes**

Expand Down
187 changes: 89 additions & 98 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@
import modelopt.torch.quantization as mtq
import modelopt.torch.sparsity as mts
from modelopt.recipe import ModelOptPTQRecipe, load_recipe
from modelopt.recipe.presets import (
KV_CACHE_NONE,
KV_QUANT_CFG_CHOICES,
QFORMAT_ALIASES,
QUANT_CFG_CHOICES,
)
from modelopt.torch.export import (
export_hf_checkpoint,
export_hf_vllm_fq_checkpoint,
Expand Down Expand Up @@ -86,56 +92,67 @@
RAND_SEED = 1234


def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"""Set use_constant_amax on KV cache quantizers.
def _kv_cfg_uses_constant_amax(kv_quant_cfg: list[dict[str, Any]]) -> bool:
"""Return True if this KV cfg pins ``use_constant_amax`` on the bmm quantizer.

Creates a new dict for the KV bmm quantizer config to avoid mutating shared references.
Cast-style KV presets (e.g. ``fp8_cast`` / ``nvfp4_cast``) set
``use_constant_amax: true`` on the ``*[kv]_bmm_quantizer`` entry; that flag
means there is no data-driven calibration to run, so callers should skip
the KV-only calibration pass. Detect the property from the YAML contents
rather than from the preset name so new cast-style presets work
automatically.
"""
for i, entry in enumerate(quant_cfg):
for entry in kv_quant_cfg:
if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
continue
cfg = entry.get("cfg") or {}
assert isinstance(cfg, dict)
quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
break


QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
"fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a16_nvfp4": mtq.W4A16_NVFP4_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
"nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG,
"nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
"nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
}

KV_QUANT_CFG_CHOICES = {
"none": "none",
"fp8_cast": "FP8_KV_CFG",
"fp8": "FP8_KV_CFG",
"fp8_affine": "FP8_AFFINE_KV_CFG",
"nvfp4_cast": "NVFP4_KV_CFG",
"nvfp4": "NVFP4_KV_CFG",
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
"nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
}

# Formats that use use_constant_amax (no calibration needed).
_KV_CAST_FORMATS = {"fp8_cast", "nvfp4_cast"}
return bool(cfg.get("use_constant_amax"))
return False


# Formats supported by mtq.auto_quantize unified-checkpoint export.
#
# This stays hardcoded — and intentionally not derived from the preset directory —
# because auto_quantize compatibility is a property of the export path (the unified
# HF checkpoint writer, TRT-LLM consumer constraints, layer-wise mixing rules), not
# of the YAML itself. A preset can exist and be valid for plain PTQ while not being
# safe to mix into an auto_quantize search. Update this set when adding/removing a
# format from auto_quantize support.
#
# NOTE: auto_quantize is being refactored/reimplemented; this table and the
# _canonical_qformat helper below are expected to be removed in the near future, so
# deliberately not invested in deriving them from the presets.
_AUTO_QUANTIZE_QFORMATS: frozenset[str] = frozenset(
{
"fp8",
"int8_smoothquant",
"int8_weight_only",
"int4_awq",
"nvfp4",
"nvfp4_awq_lite",
"nvfp4_w4a4_weight_mse_fp8_sweep",
"w4a8_awq_beta",
"fp8_2d_blockwise_weight_only",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_experts_only",
"nvfp4_omlp_only",
"nvfp4_w4a4_weight_local_hessian",
"mxfp8",
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
)


def _canonical_qformat(name: str) -> str:
"""Resolve a user-provided qformat token to its canonical preset basename.

Lets membership checks (e.g. against :data:`_AUTO_QUANTIZE_QFORMATS`) accept
either the short alias (``int8_sq``) or the canonical YAML basename
(``int8_smoothquant``). Unknown tokens pass through unchanged so the existing
error paths still fire.
"""
return QFORMAT_ALIASES.get(name, name)


mto.enable_huggingface_checkpointing()

Expand Down Expand Up @@ -311,27 +328,11 @@ def auto_quantize(

qformat_list = args.qformat.split(",")
assert qformat_list, "No quantization formats provided"
# Check if all provided quantization formats are supported
# Check if all provided quantization formats are supported. Canonicalize first so
# callers may pass either the short alias (``int8_sq``) or the canonical YAML
# basename (``int8_smoothquant``).
assert all(
qformat
in [
"fp8",
"int8_sq",
"int8_wo",
"int4_awq",
"nvfp4",
"nvfp4_awq",
"nvfp4_mse",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_experts_only",
"nvfp4_omlp_only",
"nvfp4_local_hessian",
"mxfp8",
]
for qformat in qformat_list
_canonical_qformat(qformat) in _AUTO_QUANTIZE_QFORMATS for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

# When language_model is a base text model without lm_head (e.g. Gemma4TextModel),
Expand Down Expand Up @@ -417,21 +418,16 @@ def forward_step(model, batch):

calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
# We need to explicitly set up KV cache quantization after auto_quantize
enable_quant_kv_cache = args.kv_cache_qformat != "none"
enable_quant_kv_cache = args.kv_cache_qformat != KV_CACHE_NONE
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
if enable_quant_kv_cache:
kv_cache_quant_cfg = copy.deepcopy(
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
)
kv_cache_quant_cfg = copy.deepcopy(KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"])
kv_cache_quant_cfg = [
e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*"
] # keep other quantizers from auto_quantize

if args.kv_cache_qformat in _KV_CAST_FORMATS:
_set_kv_cache_constant_amax(kv_cache_quant_cfg)

mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg)
if args.kv_cache_qformat not in _KV_CAST_FORMATS:
if not _kv_cfg_uses_constant_amax(kv_cache_quant_cfg):
# Calibrate only the KV cache quantizers; disable all others.
with mtq.set_quantizer_by_cfg_context(
language_model,
Expand All @@ -455,21 +451,14 @@ def load_model(args: argparse.Namespace):
)
else:
assert args.qformat in QUANT_CFG_CHOICES, (
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
f"Quantization format is not supported for low memory mode. Supported formats: {list(QUANT_CFG_CHOICES)}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
if args.kv_cache_qformat != "none":
if args.kv_cache_qformat != KV_CACHE_NONE:
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"],
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
# Mirror the use_constant_amax logic from quantize_main so that init_quantized_weights
# builds the KV quantizers with use_constant_amax already set. In calibration_only mode
# mtq.calibrate() does not re-apply quant_cfg, so this must happen before
# init_quantized_weights runs.
if args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

# Do not use real quant GEMM so the calibration can be more accurate.
with init_quantized_weights(
Expand Down Expand Up @@ -1103,7 +1092,7 @@ def _is_layerwise(obj):
)

assert args.qformat in QUANT_CFG_CHOICES, (
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}"
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES)}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]

Expand All @@ -1113,14 +1102,14 @@ def _is_layerwise(obj):
args.moe_calib_experts_ratio,
)

enable_quant_kv_cache = args.kv_cache_qformat != "none"
enable_quant_kv_cache = args.kv_cache_qformat != KV_CACHE_NONE
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"],
)

# Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92).
Expand All @@ -1135,14 +1124,6 @@ def _is_layerwise(obj):
quant_cfg["quant_cfg"].append({"quantizer_name": pattern, "enable": False})
print(f"Excluding MTP layer from quantization: {pattern}")

# Use constant amax for KV quantizers when a cast format is selected.
# Recipes are authoritative for KV cache config (including use_constant_amax),
# so skip this post-hoc override when --recipe is used; rely on the YAML instead
# (see modelopt_recipes/general/ptq/*_cast_kv.yaml).
if args.recipe is None and args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

if needs_checkpoint_path_update(quant_cfg):
quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
print(
Expand Down Expand Up @@ -1293,12 +1274,12 @@ def parse_args() -> argparse.Namespace:
"--kv_cache_qformat",
required=False,
default="fp8_cast",
choices=KV_QUANT_CFG_CHOICES.keys(),
choices=[KV_CACHE_NONE, *KV_QUANT_CFG_CHOICES],
help=(
"Specify KV cache quantization format. Default: fp8_cast. "
"Formats ending in '_cast' (fp8_cast, nvfp4_cast) set the amax to FP8 range "
"without data-driven calibration. "
"Other formats (fp8, nvfp4, etc.) use data-driven calibration. "
"Formats whose preset pins use_constant_amax on the KV bmm quantizer "
"(e.g. fp8_cast, nvfp4_cast) set the amax to FP8 range without data-driven "
"calibration; all other formats (fp8, nvfp4, ...) use data-driven calibration. "
"Ignored when --recipe is given: the recipe YAML is authoritative for KV "
"cache config (use the *_cast_kv.yaml recipes for the cast variants)."
),
Expand Down Expand Up @@ -1475,6 +1456,16 @@ def parse_args() -> argparse.Namespace:
if args.specdec_offline_dataset is not None and args.low_memory_mode:
parser.error("--specdec_offline_dataset is not compatible with --low_memory_mode.")

# The low-memory loader pre-instruments quantizers from --qformat/--kv_cache_qformat
# via init_quantized_weights(), so it cannot honor a --recipe (which is authoritative
# for the quant layout in quantize_main). Reject the combination rather than silently
# instrumenting a layout that diverges from the recipe.
if args.low_memory_mode and args.recipe is not None:
parser.error(
"--low_memory_mode does not yet support --recipe; the low-memory loader still "
"initializes quantizers from --qformat/--kv_cache_qformat."
)

return args


Expand Down
31 changes: 6 additions & 25 deletions examples/llm_ptq/multinode_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import time
import warnings
from pathlib import Path
from typing import Any

import numpy as np
import torch
Expand All @@ -34,6 +33,7 @@

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.recipe.presets import KV_CACHE_NONE, KV_QUANT_CFG_CHOICES, QUANT_CFG_CHOICES
from modelopt.torch.export import get_model_type
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
Expand All @@ -44,25 +44,6 @@
# Constants
RAND_SEED = 1234

QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
"nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG,
}

KV_QUANT_CFG_CHOICES = {
"none": "none",
"fp8": "FP8_KV_CFG",
"nvfp4": "NVFP4_KV_CFG",
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
}


# Enable HuggingFace checkpointing
mto.enable_huggingface_checkpointing()
Expand All @@ -80,13 +61,13 @@ def parse_args():
parser.add_argument(
"--qformat",
default="fp8",
choices=QUANT_CFG_CHOICES.keys(),
choices=list(QUANT_CFG_CHOICES),
help="Quantization format",
)
parser.add_argument(
"--kv_cache_qformat",
default="fp8",
choices=list(KV_QUANT_CFG_CHOICES.keys()),
choices=[KV_CACHE_NONE, *KV_QUANT_CFG_CHOICES],
help="KV cache quantization format",
)
parser.add_argument(
Expand Down Expand Up @@ -280,7 +261,7 @@ def main(args):
# Validate quantization format
if args.qformat not in QUANT_CFG_CHOICES:
raise ValueError(
f"Quantization format {args.qformat} not supported. Choose from: {QUANT_CFG_CHOICES.keys()}"
f"Quantization format {args.qformat} not supported. Choose from: {list(QUANT_CFG_CHOICES)}"
)

# Set random seeds
Expand Down Expand Up @@ -334,14 +315,14 @@ def main(args):
args.awq_block_size,
)

enable_quant_kv_cache = args.kv_cache_qformat != "none"
enable_quant_kv_cache = args.kv_cache_qformat != KV_CACHE_NONE
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"],
)

# Quantize the model
Expand Down
Loading
Loading