diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 62f2b0041cb..8e5b8b37a22 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -30,6 +30,7 @@ Changelog **New Features** +- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. - Add end-to-end tutorial for Minitron pruning + distillation + quantization + evaluation + vLLM deployment for Nemotron-Nano-9B-v2 → Pruned 7B along with data blend preparation steps (and ablation study). See `examples/pruning/minitron/README.md `_ for details. - Add Puzzletron - a new algorithm for heterogeneous pruning of LLM and VLM models. See `examples/puzzletron/README.md `_ for more details. diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 875e78ceea6..d52c3ee40bb 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -113,6 +113,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, "fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG, "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG, + "w4a16_nvfp4": mtq.W4A16_NVFP4_CFG, "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, @@ -785,6 +786,12 @@ def export_quantized( extra_state_dict=mtp_state_dict, ) + if args.qformat == "w4a16_nvfp4": + warnings.warn( + "TensorRT-LLM and SGLang do not support this format. " + "vLLM deployment support is in progress." + ) + # Restore default padding and export the tokenizer as well. if tokenizer is not None: tokenizer.padding_side = default_padding_side @@ -1147,7 +1154,7 @@ def _is_layerwise(obj): quant_cfg = copy.deepcopy(quant_cfg) force_weight_quantizers_static(quant_cfg["quant_cfg"]) - if args.qformat in QUANT_CFG_CHOICES: + if quant_cfg: mono_quantize( args, quant_cfg, diff --git a/modelopt/torch/export/convert_hf_config.py b/modelopt/torch/export/convert_hf_config.py index 5f8c3f3b55c..06e5923a30f 100644 --- a/modelopt/torch/export/convert_hf_config.py +++ b/modelopt/torch/export/convert_hf_config.py @@ -57,6 +57,11 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None) return { "weights": {"dynamic": False, "num_bits": 4, "type": "int", "group_size": gs}, } + elif quant_algo == "W4A16_NVFP4": + gs = group_size or 16 + return { + "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs}, + } elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"): gs = group_size or 128 return { @@ -183,6 +188,14 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An "targets": ["Linear"], } new_config["config_groups"] = {"group_0": config_group_details} + elif quant_algo_value == "W4A16_NVFP4": + # Weight-only FP4 + group_size = original_quantization_details.get("group_size", 16) + config_group_details = { + "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size}, + "targets": ["Linear"], + } + new_config["config_groups"] = {"group_0": config_group_details} elif quant_algo_value == "MIXED_PRECISION": quantized_layers = original_quantization_details.get("quantized_layers", {}) diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py index dce39767c76..5f92cc2e5dc 100755 --- a/modelopt/torch/export/model_config.py +++ b/modelopt/torch/export/model_config.py @@ -38,6 +38,7 @@ QUANTIZATION_MXFP4 = "mxfp4" QUANTIZATION_MXFP8 = "mxfp8" QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8" +QUANTIZATION_W4A16_NVFP4 = "w4a16_nvfp4" QUANTIZATION_NVFP4_AWQ = "nvfp4_awq" QUANTIZATION_FP8_PB_REAL = "fp8_pb_real" QUANTIZATION_FP8_PB_WO = "fp8_pb_wo" diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index 76f304a478a..cf9f26d51a7 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -69,6 +69,7 @@ QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_W4A8_NVFP4_FP8, + QUANTIZATION_W4A16_NVFP4, ) logger = logging.getLogger(__name__) @@ -359,6 +360,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") -> QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, + QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_NVFP4_FP8, ]: # Calibrate weight quantizer if amax is not set @@ -403,6 +405,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight") QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, + QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_NVFP4_FP8, ]: # Calibrate weight quantizer if amax is not set @@ -641,6 +644,9 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames return QUANTIZATION_NVFP4_AWQ if getattr(layer, "fused_with_prequant", False): return QUANTIZATION_NVFP4_AWQ + if input_quantizer is None or not input_quantizer.is_enabled: + if scale_bits == (4, 3): + return QUANTIZATION_W4A16_NVFP4 assert input_quantizer is not None, ( f"input_quantizer is None for {quantizer_attr_names}" ) @@ -808,6 +814,11 @@ def process_layer_quant_config(layer_config_dict): "quant_algo": "NVFP4", "group_size": block_size_value, } + elif v == "w4a16_nvfp4": + layer_config = { + "quant_algo": "W4A16_NVFP4", + "group_size": block_size_value, + } elif v == "nvfp4_awq": layer_config = { "quant_algo": "NVFP4_AWQ", @@ -985,6 +996,7 @@ def to_quantized_weight( if quantization in [ QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, + QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_NVFP4_FP8, QUANTIZATION_NVFP4_SVDQUANT, ]: diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 73ae63a5a56..0626d0a8fd5 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -83,6 +83,7 @@ QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_NVFP4_FP8, + QUANTIZATION_W4A16_NVFP4, ) from .model_utils import get_language_model_from_vl, is_multimodal_model from .moe_utils import _export_fused_experts @@ -521,6 +522,7 @@ def _export_quantized_weight( QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, QUANTIZATION_NVFP4, + QUANTIZATION_W4A16_NVFP4, QUANTIZATION_W4A8_AWQ, QUANTIZATION_W4A8_NVFP4_FP8, ]: @@ -550,6 +552,7 @@ def _export_quantized_weight( QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT, + QUANTIZATION_W4A16_NVFP4, ]: # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim) # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index dfed54cc991..b450eb5fa0d 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1684,6 +1684,7 @@ def _nvfp4_selective_quant_cfg( ], "algorithm": "max", } +W4A16_NVFP4_CFG = _nvfp4_selective_quant_cfg(["*"], weight_only=True) MXFP4_MLP_WEIGHT_ONLY_CFG = { "quant_cfg": [ @@ -1740,6 +1741,7 @@ def _nvfp4_selective_quant_cfg( "NVFP4_FP8_MHA_CONFIG", "NVFP4_KV_CFG", "NVFP4_KV_ROTATE_CFG", + "W4A16_NVFP4_CFG", "W4A8_NVFP4_FP8_CFG", "NVFP4_SVDQUANT_DEFAULT_CFG", "W4A8_AWQ_BETA_CFG", diff --git a/modelopt_recipes/configs/ptq/units/w4_nvfp4.yaml b/modelopt_recipes/configs/ptq/units/w4_nvfp4.yaml new file mode 100644 index 00000000000..b4676dbff34 --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/w4_nvfp4.yaml @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# W4A16 NVFP4: NVFP4 E2M1 dynamic weight quantizer only; activations remain in BF16. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig +imports: + nvfp4: configs/numerics/nvfp4 +--- + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 diff --git a/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml b/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml new file mode 100644 index 00000000000..416572e0f80 --- /dev/null +++ b/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + w4a16_nvfp4: configs/ptq/units/w4_nvfp4 + +metadata: + recipe_type: ptq + description: NVFP4 W4A16 weight-only, BF16 activations, max calibration. No calibration forward pass required. +quantize: + algorithm: max + quant_cfg: + - $import: base_disable_all + - $import: w4a16_nvfp4 + - $import: default_disabled_quantizers diff --git a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py index 8bdf3f5e659..6e0c56bfd1d 100644 --- a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py +++ b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py @@ -47,6 +47,7 @@ ("w4a8_awq", "tiny_llama-w4a8-awq", True, False, True, True, False), ("int8_wo", "tiny_llama-int8-wo", False, False, False, False, False), ("nvfp4_svdquant", "tiny_llama-nvfp4-svdquant", True, False, True, True, True), + ("w4a16_nvfp4", "tiny_llama-w4a16-nvfp4", False, False, False, False, False), # MoE models (fused experts: Qwen3 MoE, GPT-OSS) ("nvfp4", "tiny_qwen3_moe-nvfp4", True, False, True, True, False), ("fp8", "tiny_gpt_oss-fp8", True, False, True, True, False),