Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f453f92
yaml for all hard coded PTQ configs
shengliangxu May 6, 2026
965cbf4
numerics yaml
shengliangxu May 7, 2026
5d373b5
Remove quantize config loader wrapper
shengliangxu May 8, 2026
60deb60
Add KV quantization config units
shengliangxu May 8, 2026
a971c75
Remove stale FP8 config comments
shengliangxu May 9, 2026
d4a9df6
update int4 int8
shengliangxu May 9, 2026
d834078
update descriptions
shengliangxu May 9, 2026
37f45ec
Move Diffusers quant configs to YAML
shengliangxu May 15, 2026
ca33e92
Use numerics imports in YAML configs
shengliangxu May 15, 2026
683b0ef
Factor shared quant_cfg blocks into reusable YAML units
shengliangxu May 15, 2026
fa80c0e
Move W4A16_NVFP4_CFG to YAML
shengliangxu May 18, 2026
296934a
Address review feedback from PR #1423
shengliangxu May 18, 2026
7463504
Restore metadata blocks in speculative-decoding recipe YAMLs
shengliangxu May 18, 2026
51c8e5f
Keep *_CFG constants as plain dicts; simplify cfg-list loader
shengliangxu May 18, 2026
1e0071b
Merge branch 'main' into shengliangx/all-yaml-configs
shengliangxu May 18, 2026
f07d549
Fix llm_autodeploy SUPPORT_QUANT_FORMAT type annotation
shengliangxu May 18, 2026
2af92ac
Merge branch 'main' into shengliangx/all-yaml-configs
shengliangxu May 19, 2026
e9efded
Revert SmoothQuant gate in diffusers quantize.py
shengliangxu May 19, 2026
b0e9d65
Merge branch 'main' into shengliangx/all-yaml-configs
shengliangxu May 19, 2026
b92ee30
Drop dead model_dump shim in diffusers quantize.py
shengliangxu May 19, 2026
23dad34
Merge branch 'main' into shengliangx/all-yaml-configs
shengliangxu May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 14 additions & 75 deletions examples/diffusers/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,82 +16,21 @@
import torch.nn as nn
from calib.plugin_calib import PercentileCalibrator

FP8_DEFAULT_CONFIG = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
{"quantizer_name": "*input_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
{"quantizer_name": "*output_quantizer", "enable": False},
{"quantizer_name": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
],
"algorithm": "max",
}
from modelopt.torch.opt.config_loader import load_config
from modelopt.torch.quantization.config import QuantizeConfig

INT8_DEFAULT_CONFIG = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}},
{"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}},
{"quantizer_name": "*output_quantizer", "enable": False},
],
"algorithm": "max",
}

NVFP4_DEFAULT_CONFIG = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "*weight_quantizer",
"cfg": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
},
"enable": True,
},
{
"quantizer_name": "*input_quantizer",
"cfg": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
},
"enable": True,
},
{"quantizer_name": "*output_quantizer", "enable": False},
{"quantizer_name": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
],
"algorithm": "max",
}

NVFP4_FP8_MHA_CONFIG = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "**weight_quantizer",
"cfg": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
},
"enable": True,
},
{
"quantizer_name": "**input_quantizer",
"cfg": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
},
"enable": True,
},
{"quantizer_name": "*output_quantizer", "enable": False},
{"quantizer_name": "*[qkv]_bmm_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
{"quantizer_name": "*softmax_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
{"quantizer_name": "*bmm2_output_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}},
],
"algorithm": {"method": "svdquant", "lowrank": 32},
}
# Diffusers FP8 preset. The config is loaded by name through load_config with
# QuantizeConfig as the schema, then converted with model_dump(exclude_unset=True)
# so the constant is a plain dict holding only the fields the preset sets
# explicitly (assumes QuantizeConfig is a pydantic model -- TODO confirm).
FP8_DEFAULT_CONFIG = load_config(
"configs/ptq/presets/diffusers/fp8", schema_type=QuantizeConfig
).model_dump(exclude_unset=True)
Comment thread
shengliangxu marked this conversation as resolved.
# Remaining diffusers presets (INT8, NVFP4, NVFP4 with FP8 MHA). Each one is
# loaded by name through load_config with QuantizeConfig as the schema and then
# dumped with model_dump(exclude_unset=True), so every constant is a plain dict
# containing only explicitly-set fields (assumes QuantizeConfig is a pydantic
# model -- TODO confirm).
INT8_DEFAULT_CONFIG = load_config(
"configs/ptq/presets/diffusers/int8", schema_type=QuantizeConfig
).model_dump(exclude_unset=True)
NVFP4_DEFAULT_CONFIG = load_config(
"configs/ptq/presets/diffusers/nvfp4", schema_type=QuantizeConfig
).model_dump(exclude_unset=True)
# NOTE(review): the inline dict this replaces used {"method": "svdquant", "lowrank": 32}
# as its algorithm rather than "max" -- verify the preset file preserves that.
NVFP4_FP8_MHA_CONFIG = load_config(
"configs/ptq/presets/diffusers/nvfp4_fp8_mha", schema_type=QuantizeConfig
).model_dump(exclude_unset=True)


def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, **kwargs):
Expand Down
26 changes: 17 additions & 9 deletions examples/diffusers/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import argparse
import copy
import logging
import sys
import time as time
Expand Down Expand Up @@ -114,19 +115,13 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
"""
self.logger.info(f"Building quantization config for {self.config.format.value}")

apply_int8_percentile_calibrator = False
if self.config.format == QuantFormat.INT8:
if self.config.algo == QuantAlgo.SMOOTHQUANT:
base_cfg = mtq.INT8_SMOOTHQUANT_CFG
else:
base_cfg = INT8_DEFAULT_CONFIG
if self.config.collect_method != CollectMethod.DEFAULT:
reset_set_int8_config(
base_cfg,
self.config.percentile,
n_steps,
collect_method=self.config.collect_method.value,
backbone=backbone,
)
apply_int8_percentile_calibrator = self.config.collect_method != CollectMethod.DEFAULT
elif self.config.format == QuantFormat.FP8:
Comment thread
shengliangxu marked this conversation as resolved.
base_cfg = FP8_DEFAULT_CONFIG
elif self.config.format == QuantFormat.FP4:
Expand All @@ -137,7 +132,20 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
else:
raise NotImplementedError(f"Unknown format {self.config.format}")

# Build a fresh config dict so we never mutate the global constants.
# Build a fresh config dict so runtime overrides never mutate the global constants.
base_cfg = copy.deepcopy(base_cfg)
if hasattr(base_cfg, "model_dump"):
base_cfg = base_cfg.model_dump(exclude_unset=True)

if apply_int8_percentile_calibrator:
reset_set_int8_config(
base_cfg,
self.config.percentile,
n_steps,
collect_method=self.config.collect_method.value,
backbone=backbone,
)

quant_cfg_list = list(base_cfg["quant_cfg"])

if self.config.format == QuantFormat.FP4:
Expand Down
5 changes: 3 additions & 2 deletions examples/llm_autodeploy/run_auto_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import argparse
from collections import defaultdict
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
Expand All @@ -24,7 +25,7 @@
from modelopt.torch.utils import create_forward_loop
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader

SUPPORT_QUANT_FORMAT = {
SUPPORT_QUANT_FORMAT: dict[str, dict[str, Any]] = {
"fp8": mtq.FP8_DEFAULT_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
}
Expand Down Expand Up @@ -87,7 +88,7 @@ def loss_func(output, data):
data_loader=calib_dataloader,
forward_step=lambda model, batch: model(**batch),
loss_func=loss_func,
quantization_formats=[SUPPORT_QUANT_FORMAT[format] for format in qformat_list],
quantization_formats=[SUPPORT_QUANT_FORMAT[quant_format] for quant_format in qformat_list],
num_calib_steps=len(calib_dataloader),
num_score_steps=min(
len(calib_dataloader), 128 // batch_size
Expand Down
20 changes: 19 additions & 1 deletion modelopt/torch/opt/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,19 @@ def _schema_equal(left: Any | None, right: Any | None) -> bool:
def _list_element_schema(schema_type: Any | None) -> Any | None:
"""Return the element schema for a typed ``list[T]`` annotation."""
schema_type = _unwrap_schema_type(schema_type)
if get_origin(schema_type) is not list:
origin = get_origin(schema_type)
if origin in (UnionType, Union):
element_schemas = []
for arg in get_args(schema_type):
if arg is NoneType:
continue
element_schema = _list_element_schema(arg)
if element_schema is None:
continue
if not any(_schema_equal(element_schema, seen) for seen in element_schemas):
element_schemas.append(element_schema)
return element_schemas[0] if len(element_schemas) == 1 else None
if origin is not list:
return None
args = get_args(schema_type)
if len(args) != 1 or args[0] is Any:
Expand Down Expand Up @@ -510,6 +522,12 @@ def _resolve_list_import(
if _schema_equal(imported.schema_type, element_schema):
return [imported.data]

element_schema_unwrapped = _unwrap_schema_type(element_schema)
if isinstance(imported.data, dict) and (
element_schema_unwrapped is dict or get_origin(element_schema_unwrapped) is dict
):
return [imported.data]

raise ValueError(
f"$import {ref_name!r} in list at {context} has schema "
f"{_schema_label(imported.schema_type, imported.schema)!r}; expected either "
Expand Down
Loading
Loading