Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/features/quantization/modelopt.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ following `quantization.quant_algo` values:
- `FP8_PER_CHANNEL_PER_TOKEN`: per-channel weight scale and dynamic per-token activation quantization.
- `FP8_PB_WO` (ModelOpt may emit `fp8_pb_wo`): block-scaled FP8 weight-only (typically 128×128 blocks).
- `NVFP4`: ModelOpt NVFP4 checkpoints (use `quantization="modelopt_fp4"`).
- `MXFP8`: ModelOpt MXFP8 checkpoints (use `quantization="modelopt_mxfp8"`).

## Quantizing HuggingFace Models with PTQ

Expand Down
1 change: 1 addition & 0 deletions vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ def _verify_quantization(self) -> None:
"moe_wna16",
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
"petit_nvfp4",
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,13 +477,15 @@ def make(
"mxfp4",
"mxfp6_e3m2",
"mxfp6_e2m3",
"mxfp8",
}
assert not isinstance(weight_dtype, str) or weight_dtype in {
"nvfp4",
"mxfp4",
"mxfp6_e3m2",
"mxfp6_e2m3",
"int4",
"mxfp8",
}

if weight_dtype is None:
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"fp_quant",
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
"gguf",
"gptq_marlin",
"awq_marlin",
Expand Down Expand Up @@ -119,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig
from .inc import INCConfig
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .petit import PetitNvFp4Config
Expand All @@ -133,6 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"fp_quant": FPQuantConfig,
"modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptNvFp4Config,
"modelopt_mxfp8": ModelOptMxFp8Config,
"gguf": GGUFConfig,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
Expand Down
249 changes: 247 additions & 2 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
MXFP8_SCALE_DTYPE,
MXFP8_VALUE_DTYPE,
Mxfp8LinearBackend,
Mxfp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format,
Expand Down Expand Up @@ -103,6 +110,8 @@
"FP8_PB_WO",
# FP4
"NVFP4",
# MXFP8
"MXFP8",
]
KV_CACHE_QUANT_ALGOS = ["FP8"]

Expand Down Expand Up @@ -386,12 +395,12 @@ def override_quantization_method(
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = str(quant_config.get("quant_algo", ""))
if "FP8" in quant_algo.upper():
if quant_algo.upper() == "FP8":
return "modelopt"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
if "FP8" in quant_algo.upper():
if quant_algo.upper() == "FP8":
return "modelopt"

return None
Expand Down Expand Up @@ -1549,3 +1558,239 @@ def apply(
# Wire ModelOptNvFp4Config to its method implementations; the KV-cache
# slot reuses the FP8 KV-cache method.
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptMxFp8Config(ModelOptQuantConfigBase):
    """Quantization config for ModelOpt MXFP8 checkpoints.

    Only pre-quantized (serialized) checkpoints are accepted; dynamic
    MXFP8 quantization is rejected at construction time.
    """

    def __init__(
        self,
        is_checkpoint_mxfp8_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
    ) -> None:
        super().__init__(exclude_modules)
        self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized

        # Serialized checkpoints only: there is no on-the-fly MXFP8 path.
        if not is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization requires a serialized checkpoint. "
                "Dynamic quantization is not supported."
            )

        logger.warning(
            "Detected ModelOpt MXFP8 checkpoint. Please note that "
            "the format is experimental and could change in future."
        )

        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        return "modelopt_mxfp8"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        # Only bfloat16 activations are supported.
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # MXFP8 hardware acceleration requires Blackwell (SM100) or newer
        return 100

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        # No MXFP8 implementation exists for fused-MoE layers yet; fail
        # loudly instead of silently falling back.
        if isinstance(layer, FusedMoE):
            raise NotImplementedError(
                "MXFP8 quantization does not yet support MoE models. "
                "Please use FP8 or NVFP4 quantization for MoE models."
            )
        return super().get_quant_method(layer, prefix)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Return "modelopt_mxfp8" when the HF quant config describes a
        ModelOpt MXFP8 checkpoint, else None."""
        if hf_quant_cfg is None:
            return None

        # Only handle configs that explicitly declare the ModelOpt method.
        if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
            return None

        # ModelOpt-native configs nest details under "quantization";
        # otherwise fall back to a flat (compressed-tensors style) layout.
        if "quantization" in hf_quant_cfg:
            nested = hf_quant_cfg["quantization"]
            algo = (
                str(nested.get("quant_algo", "")).upper()
                if isinstance(nested, dict)
                else ""
            )
        else:
            algo = str(hf_quant_cfg.get("quant_algo", "")).upper()

        return "modelopt_mxfp8" if "MXFP8" in algo else None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptMxFp8Config":
        serialized = "MXFP8" in quant_method.upper()

        # A serialized MXFP8 checkpoint must spell out these keys in the
        # "quantization" section of hf_quant_config.json.
        if serialized and "quantization" in original_config:
            section = original_config["quantization"]
            missing_fields = [
                key
                for key in ["kv_cache_quant_algo", "exclude_modules"]
                if key not in section
            ]
            if missing_fields:
                raise ValueError(
                    f"MXFP8 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(serialized, kv_cache_quant_method, exclude_modules)


class ModelOptMxFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt MXFP8 quantization.

    Weights are stored as FP8 E4M3 values with one uint8 (E8M0-encoded)
    scale per MXFP8_BLOCK_SIZE-wide block along the input (K) dimension.
    The actual GEMM is delegated to an Mxfp8LinearOp instance.
    """

    def __init__(self, quant_config: ModelOptMxFp8Config) -> None:
        self.quant_config = quant_config

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 currently only supports serialized checkpoints. "
                "Dynamic quantization is not supported."
            )

        # Only the emulation backend is wired up here.
        chosen_backend = Mxfp8LinearBackend.EMULATION
        self.mxfp8_linear_op = Mxfp8LinearOp(backend=chosen_backend)
        logger.info_once("Using %s backend for MXFP8 GEMM", chosen_backend.value)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the (uninitialized) MXFP8 weight and scale parameters."""
        del input_size, output_size  # unused: per-partition sizes suffice

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization was selected, but checkpoint is not "
                "MXFP8 serialized. Dynamic quantization is not supported."
            )

        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features

        # K must tile evenly into MXFP8 scale blocks.
        if input_size_per_partition % MXFP8_BLOCK_SIZE != 0:
            raise ValueError(
                f"MXFP8 requires input dimension to be divisible by "
                f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}"
            )

        # Quantized values: FP8 E4M3, shape [N, K].
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    out_features,
                    input_size_per_partition,
                    dtype=MXFP8_VALUE_DTYPE,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        # Block scales: one E8M0 byte per MXFP8_BLOCK_SIZE elements of K.
        layer.register_parameter(
            "weight_scale",
            ModelWeightParameter(
                data=torch.empty(
                    out_features,
                    input_size_per_partition // MXFP8_BLOCK_SIZE,
                    dtype=MXFP8_SCALE_DTYPE,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Validate loaded tensors and trim the scale to the weight's extents."""
        if layer.weight.ndim != 2:
            raise ValueError(
                f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
                f"with shape {tuple(layer.weight.shape)}"
            )

        if layer.weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), "
                f"got {layer.weight.dtype}. The checkpoint may not be properly "
                f"quantized with MXFP8."
            )

        loaded = layer.weight.data  # [N, K]
        n_rows, k_cols = loaded.shape

        # The serialized scale may carry padding; keep only the region that
        # matches the actual weight shape.
        trimmed_scale = layer.weight_scale.data[
            :n_rows, : k_cols // MXFP8_BLOCK_SIZE
        ].contiguous()

        layer.weight = Parameter(loaded.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(trimmed_scale, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MXFP8 linear op on `x`; output keeps `x`'s dtype."""
        # Sanity-check dtypes before handing off to the op.
        if layer.weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}"
            )
        if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"Weight scale dtype {layer.weight_scale.dtype} != "
                f"expected {MXFP8_SCALE_DTYPE}"
            )

        return self.mxfp8_linear_op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=x.dtype,
            bias=bias,
        )


# Register the method classes for ModelOptMxFp8Config; the KV-cache slot
# reuses the FP8 KV-cache method.
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
Loading