Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
e21a9ba
feat(opt): make load_config return validated schema instances
shengliangxu May 14, 2026
198a305
feat(quant): make QuantizerCfgEntry a ModeloptBaseConfig pydantic type
shengliangxu May 14, 2026
058234a
refactor(quant): tighten type hints on QuantizeConfig field validators
shengliangxu May 14, 2026
02513b6
refactor(quant): name the legacy flat-dict quant_cfg input shape
shengliangxu May 14, 2026
c855d2a
refactor(quant): widen quant_cfg input types to Mapping/Sequence
shengliangxu May 14, 2026
0b6b2f0
need to have model_dump for explicitly set k/v
shengliangxu May 14, 2026
c7ab593
refactor(recipe): make RecipeMetadataConfig a ModeloptBaseConfig
shengliangxu May 14, 2026
df4ffb4
fix(recipe): require metadata and quantize sections
shengliangxu May 15, 2026
0d087e0
feat(opt): make ModeloptBaseConfig a real MutableMapping
shengliangxu May 15, 2026
77c8e67
fix(quant): normalize empty cfg to None when disabling a quantizer
shengliangxu May 15, 2026
c43c0f0
fix(quant): keep shared cfg snippets as dicts in public constants
shengliangxu May 15, 2026
a63d420
fix(opt): __setitem__ raises KeyError for unknown keys
shengliangxu May 15, 2026
d7b6e0a
test(quant): tighten cfg-shape rejection assertions
shengliangxu May 15, 2026
251ea06
refactor(quant): schematize QuantizerCfgEntry.cfg as QuantizerAttribu…
shengliangxu May 15, 2026
21d2b35
Merge branch 'main' into shengliangx/schematize-cfg
shengliangxu May 15, 2026
0b788e4
Merge branch 'main' into shengliangx/schematize-cfg
shengliangxu May 15, 2026
11ebae0
Merge branch 'main' into shengliangx/schematize-cfg
shengliangxu May 15, 2026
23c9060
Merge branch 'main' into shengliangx/schematize-cfg
shengliangxu May 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions modelopt/onnx/llm_export_utils/quantization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"):
else:
raise ValueError(f"Unsupported precision: {precision}")

quant_cfg_list: list = [
e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e
]
quant_cfg_list: list = [e for e in quant_cfg["quant_cfg"] if "quantizer_name" in e]

if lm_head_precision == "fp8":
quant_cfg_list.append(
Expand Down
53 changes: 24 additions & 29 deletions modelopt/recipe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@

from enum import Enum

from pydantic import field_validator
from typing_extensions import NotRequired, TypedDict
from pydantic import Field

from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
from modelopt.torch.quantization.config import QuantizeConfig
from modelopt.torch.quantization.config import QuantizeConfig # noqa: TC001


class RecipeType(str, Enum):
Expand All @@ -33,14 +32,21 @@ class RecipeType(str, Enum):
# QAT = "qat" # Not implemented yet, will be added in the future.


class RecipeMetadataConfig(TypedDict):
"""YAML shape of the recipe metadata section."""
_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe."

recipe_type: RecipeType
description: NotRequired[str]

class RecipeMetadataConfig(ModeloptBaseConfig):
"""YAML shape of the recipe metadata section."""

_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe."
recipe_type: RecipeType = Field(
title="Recipe type",
description="The type of the recipe (e.g. PTQ).",
)
description: str = ModeloptField(
default=_DEFAULT_RECIPE_DESCRIPTION,
title="Description",
description="Human-readable description of the recipe.",
)


class ModelOptRecipeBase(ModeloptBaseConfig):
Expand All @@ -49,41 +55,30 @@ class ModelOptRecipeBase(ModeloptBaseConfig):
If a layer name matches ``"*output_layer*"``, the attributes will be replaced with ``{"enable": False}``.
"""

metadata: RecipeMetadataConfig = ModeloptField(
default={"recipe_type": RecipeType.PTQ, "description": _DEFAULT_RECIPE_DESCRIPTION},
metadata: RecipeMetadataConfig = Field(
title="Metadata",
description="Recipe metadata containing the recipe type and description.",
validate_default=True,
description="Recipe metadata containing the recipe type and description. "
"Required: a recipe without a ``metadata`` section is rejected so that a "
"missing section can't silently fall back to a default recipe type.",
)

@field_validator("metadata")
@classmethod
def validate_metadata(cls, metadata: RecipeMetadataConfig) -> RecipeMetadataConfig:
"""Validate recipe metadata and fill defaults for optional fields."""
if metadata["recipe_type"] not in RecipeType:
raise ValueError(
f"Unsupported recipe type: {metadata['recipe_type']}. "
f"Only {list(RecipeType)} are currently supported."
)
return {"description": _DEFAULT_RECIPE_DESCRIPTION, **metadata}

@property
def recipe_type(self) -> RecipeType:
"""Return the recipe type from metadata."""
return self.metadata["recipe_type"]
return self.metadata.recipe_type

@property
def description(self) -> str:
"""Return the recipe description from metadata."""
return self.metadata.get("description", _DEFAULT_RECIPE_DESCRIPTION)
return self.metadata.description


class ModelOptPTQRecipe(ModelOptRecipeBase):
"""Our config class for PTQ recipes."""

quantize: QuantizeConfig = ModeloptField(
default=QuantizeConfig(),
quantize: QuantizeConfig = Field(
title="PTQ config",
description="PTQ config containing quant_cfg and algorithm.",
validate_default=True,
description="PTQ config containing quant_cfg and algorithm. Required: a PTQ "
"recipe without a ``quantize`` section is rejected so that a missing section "
"can't silently fall back to the default INT8 config.",
)
49 changes: 10 additions & 39 deletions modelopt/recipe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,29 +89,15 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas
The file must contain a ``metadata`` section with at least ``recipe_type``,
plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes.
"""
data = load_config(recipe_file, schema_type=ModelOptPTQRecipe)
if not isinstance(data, dict):
recipe = load_config(recipe_file, schema_type=ModelOptPTQRecipe)
if not isinstance(recipe, ModelOptPTQRecipe):
raise ValueError(
f"Recipe file {recipe_file} must be a YAML mapping, got {type(data).__name__}."
f"Recipe file {recipe_file} must produce a {ModelOptPTQRecipe.__name__}, "
f"got {type(recipe).__name__}."
)

metadata = data.get("metadata", {})
if not isinstance(metadata, dict):
raise ValueError(
f"Recipe file {recipe_file} field 'metadata' must be a mapping, "
f"got {type(metadata).__name__}."
)
recipe_type = metadata.get("recipe_type")
if recipe_type is None:
raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.")

recipe_type = recipe.recipe_type
if recipe_type == RecipeType.PTQ:
if "quantize" not in data:
raise ValueError(f"PTQ recipe file {recipe_file} must contain 'quantize'.")
return ModelOptPTQRecipe(
metadata=metadata,
quantize=data["quantize"],
)
return recipe
raise ValueError(f"Unsupported recipe type: {recipe_type!r}")


Expand All @@ -137,25 +123,10 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase:
quantize.
"""
metadata_file = _find_recipe_section_file(recipe_dir, "metadata")

metadata = load_config(metadata_file, schema_type=RecipeMetadataConfig)
if not isinstance(metadata, dict):
raise ValueError(
f"Metadata file {metadata_file} must be a YAML mapping, got {type(metadata).__name__}."
)
recipe_type = metadata.get("recipe_type")
if recipe_type is None:
raise ValueError(f"Metadata file {metadata_file} must contain a 'recipe_type' field.")

if recipe_type == RecipeType.PTQ:
if metadata.recipe_type == RecipeType.PTQ:
quantize_file = _find_recipe_section_file(recipe_dir, "quantize")
quantize_data = load_config(quantize_file, schema_type=QuantizeConfig)
if not isinstance(quantize_data, dict):
raise ValueError(
f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}."
)
return ModelOptPTQRecipe(
metadata=metadata,
quantize=quantize_data,
)
raise ValueError(f"Unsupported recipe type: {recipe_type!r}")
quantize_cfg = load_config(quantize_file, schema_type=QuantizeConfig)
return ModelOptPTQRecipe(metadata=metadata, quantize=quantize_cfg)
raise ValueError(f"Unsupported recipe type: {metadata.recipe_type!r}")
52 changes: 45 additions & 7 deletions modelopt/torch/opt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import fnmatch
import json
from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView
from collections.abc import Callable, ItemsView, Iterator, KeysView, MutableMapping, ValuesView
from typing import Any, TypeAlias

import torch
Expand Down Expand Up @@ -57,11 +57,18 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs): # noqa: N802
# TODO: expand config classes to searcher


class ModeloptBaseConfig(BaseModel):
class ModeloptBaseConfig(BaseModel, MutableMapping):
"""Our config base class for mode configuration.

The base class extends the capabilities of pydantic's BaseModel to provide additional methods
and properties for easier access and manipulation of the configuration.

Inherits from :class:`collections.abc.MutableMapping` so instances satisfy
``isinstance(cfg, Mapping)`` / ``isinstance(cfg, MutableMapping)`` checks and pick up the
mixin methods (``pop``, ``popitem``, ``setdefault``, ``clear``). Schema fields are fixed,
so ``__delitem__`` raises :class:`TypeError`; the inherited ``pop`` / ``clear`` /
``popitem`` therefore also raise on any existing key, while ``pop(key, default)`` for a
missing key still returns the default normally.
"""

model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True)
Expand Down Expand Up @@ -110,18 +117,49 @@ def __contains__(self, key: str) -> bool:
return False

def __getitem__(self, key: str) -> Any:
"""Get the value for the given key (can be name or alias of field)."""
return getattr(self, self.get_field_name_from_key(key))
"""Get the value for the given key (can be name or alias of field).

Raises :class:`KeyError` for missing keys so the class behaves like a regular
:class:`Mapping` — required for the inherited ``MutableMapping`` mixin methods
(``pop``, ``setdefault``, ...) to dispatch correctly.
"""
try:
return getattr(self, self.get_field_name_from_key(key))
except AttributeError:
raise KeyError(key) from None

def __setitem__(self, key: str, value: Any) -> None:
"""Set the value for the given key (can be name or alias of field)."""
setattr(self, self.get_field_name_from_key(key), value)
"""Set the value for the given key (can be name or alias of field).

Raises :class:`KeyError` (not :class:`AttributeError`) for unknown keys so the
class matches the :class:`MutableMapping` protocol — both for direct
``cfg["unknown"] = value`` writes and for inherited mixin helpers like
``setdefault`` that write through ``__setitem__``.
"""
try:
setattr(self, self.get_field_name_from_key(key), value)
except AttributeError:
raise KeyError(key) from None

def __delitem__(self, key: str) -> None:
"""Reject key deletion.

``ModeloptBaseConfig`` exposes a fixed pydantic schema, so removing a key is
ill-defined: schema fields can't disappear, and silently resetting them to their
defaults would surprise callers. Raise ``TypeError`` instead. Defined so the
class fully satisfies the ``MutableMapping`` protocol (``__delitem__`` is
required), without committing to actual deletion semantics.
"""
raise TypeError(
f"{type(self).__name__} does not support key deletion; schema fields are "
f"fixed (attempted to delete {key!r})."
)

def get(self, key: str, default: Any = None) -> Any:
"""Get the value for the given key (can be name or alias) or default if not found."""
Comment thread
shengliangxu marked this conversation as resolved.
try:
return self[key]
except AttributeError:
except KeyError:
return default

def __len__(self) -> int:
Expand Down
75 changes: 61 additions & 14 deletions modelopt/torch/opt/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@
import re
import sys
from pathlib import Path
from typing import Any, Union, get_args, get_origin, get_type_hints
from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints, overload

import yaml
from pydantic import TypeAdapter
from typing_extensions import NotRequired, Required, is_typeddict

from modelopt.torch.opt.config import ModeloptBaseConfig


@dataclass
class _ListSnippet:
Expand Down Expand Up @@ -592,29 +594,74 @@ def _find_import_marker(obj: Any, context: str = "root") -> tuple[Any, str] | No
return None


_SchemaT = TypeVar("_SchemaT", bound=ModeloptBaseConfig)


@overload
def load_config(
config_path: str | Path | Traversable,
*,
schema_type: type[_SchemaT],
) -> _SchemaT: ...


@overload
def load_config(
config_path: str | Path | Traversable,
*,
schema_type: type[list[_SchemaT]],
) -> list[_SchemaT]: ...


@overload
def load_config(
config_path: str | Path | Traversable,
*,
schema_type: None = None,
) -> Any: ...


def load_config(
config_path: str | Path | Traversable,
*,
schema_type: Any | None = None,
) -> dict[str, Any] | list[Any]:
) -> Any:
"""Load a YAML config and resolve all ``$import`` references.

This is the primary config loading entry point. It loads the YAML file,
resolves any ``imports`` / ``$import`` directives, and returns the final
config dict or list.

``schema_type`` supplies a typing context for import resolution when the
file itself has no ``modelopt-schema`` comment. It is intentionally not a
request to validate the top-level file. Top-level files are validated only
when they declare ``modelopt-schema``; imported snippets are stricter and
must always declare ``modelopt-schema``.
resolves any ``imports`` / ``$import`` directives, and returns either a
validated instance of the schema (when one is known) or the raw resolved
payload.

The effective schema is selected as follows:

1. If ``schema_type`` is provided, it is used.
2. Otherwise, the schema declared by the file's ``# modelopt-schema:``
comment (if any) is used.

When an effective schema is selected, the resolved payload is validated
and returned as an instance of that schema — e.g., a Pydantic model
instance for ``BaseModel`` schemas, or a validated dict / list for
``TypedDict`` / ``list[TypedDict]`` schemas. If neither source supplies a
schema, the raw resolved dict or list is returned unchanged.

Imported snippets are stricter and must always declare ``modelopt-schema``;
they are validated during import resolution regardless of the top-level
selection above.
"""
raw = _load_raw_config_with_schema(config_path)
data = raw.data
declared_schema_type = _schema_type(raw.schema) if raw.schema else None
resolver_schema_type = declared_schema_type or schema_type
effective_schema_type = schema_type if schema_type is not None else declared_schema_type

if isinstance(data, (_ListSnippet, dict)):
data = _resolve_imports(data, schema_type=resolver_schema_type)
_validate_modelopt_schema(raw.schema, data, raw.path, schema_type=declared_schema_type)
return data
data = _resolve_imports(data, schema_type=effective_schema_type)
if effective_schema_type is None:
return data
try:
return TypeAdapter(effective_schema_type).validate_python(data)
except Exception as exc:
raise ValueError(
f"Config file {raw.path} does not match modelopt-schema "
f"{_schema_label(effective_schema_type, raw.schema)!r}: {exc}"
) from exc
6 changes: 4 additions & 2 deletions modelopt/torch/quantization/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

from . import config as mtq_config
from . import model_calib
from .config import QuantizeConfig, QuantizerAttributeConfig
from .config import QuantizeConfig, QuantizerAttributeConfig, QuantizerCfgEntry
from .conversion import set_quantizer_by_cfg
from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import is_quantized_linear
Expand Down Expand Up @@ -129,7 +129,9 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No
# Disable KV Cache quantization
# Currently KV Cache quantization is enabled for some quantization formats and disabled for others
# This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy
self.config.quant_cfg.append({"quantizer_name": "*output_quantizer", "enable": False})
self.config.quant_cfg.append(
QuantizerCfgEntry(quantizer_name="*output_quantizer", enable=False)
)

self.compression = estimate_quant_compression(self.config)

Expand Down
Loading
Loading