From a7da25342fa3f8d0cc9615763e6b5af9c3df65f5 Mon Sep 17 00:00:00 2001 From: MerkyorLynn <268568828+MerkyorLynn@users.noreply.github.com> Date: Sat, 6 Jun 2026 00:59:35 +0800 Subject: [PATCH] Document SM120 ModelOpt NVFP4 Marlin path Signed-off-by: MerkyorLynn <268568828+MerkyorLynn@users.noreply.github.com> --- docs/features/quantization/modelopt.md | 32 ++++++++++++++++++- .../kernels/linear/nvfp4/marlin.py | 8 ++--- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/docs/features/quantization/modelopt.md b/docs/features/quantization/modelopt.md index ad417bcb30ae..f4c65090abb8 100644 --- a/docs/features/quantization/modelopt.md +++ b/docs/features/quantization/modelopt.md @@ -16,7 +16,9 @@ following `quantization.quant_algo` values: - `FP8`: per-tensor weight scale (+ optional static activation scale). - `FP8_PER_CHANNEL_PER_TOKEN`: per-channel weight scale and dynamic per-token activation quantization. - `FP8_PB_WO` (ModelOpt may emit `fp8_pb_wo`): block-scaled FP8 weight-only (typically 128×128 blocks). -- `NVFP4`: ModelOpt NVFP4 checkpoints (use `quantization="modelopt_fp4"`). +- `NVFP4`: ModelOpt W4A4 NVFP4 checkpoints (use `quantization="modelopt_fp4"`). +- `W4A16_NVFP4`: ModelOpt weight-only NVFP4 checkpoints with fp16/bf16 activations. +- `MIXED_PRECISION`: per-layer ModelOpt checkpoints that combine the formats above, for example FP8 attention layers with W4A16 NVFP4 MoE experts. - `MXFP8`: ModelOpt MXFP8 checkpoints (use `quantization="modelopt_mxfp8"`). ## Quantizing HuggingFace Models with PTQ @@ -102,6 +104,34 @@ vllm serve \ --host 0.0.0.0 --port 8000 ``` +## Serving W4A16 NVFP4 MoE checkpoints with Marlin + +Some ModelOpt NVFP4 MoE checkpoints are exported as +`quantization.quant_algo = "MIXED_PRECISION"` and mark MoE expert layers (and +sometimes `lm_head`) as `W4A16_NVFP4` in `hf_quant_config.json`. This is a +weight-only NVFP4 format: weights are stored in 4-bit NVFP4, while activations +remain fp16/bf16. It is served by the Marlin W4A16 path, not by W4A4 kernels +that expect runtime activation quantization. + +For reproducible debugging and benchmarking of W4A16 NVFP4 checkpoints on +CUDA GPUs where Marlin FP4 is available, you can explicitly pin the Marlin +linear and MoE backends: + +```bash +vllm serve nvidia/Qwen3.6-35B-A3B-NVFP4 \ + --quantization modelopt \ + --linear-backend marlin \ + --moe-backend marlin \ + --kv-cache-dtype fp8_e4m3 \ + --enable-chunked-prefill \ + --enable-prefix-caching \ + --host 0.0.0.0 --port 8000 +``` + +When debugging startup, check the logs for the Marlin NVFP4 linear and MoE +backend selections. Also run a short generation sanity check before comparing +latency or throughput. + ## Testing (local checkpoints) vLLM's ModelOpt unit tests are gated by local checkpoint paths and are skipped diff --git a/vllm/model_executor/kernels/linear/nvfp4/marlin.py b/vllm/model_executor/kernels/linear/nvfp4/marlin.py index a05d6823c881..f3e71d2457b9 100644 --- a/vllm/model_executor/kernels/linear/nvfp4/marlin.py +++ b/vllm/model_executor/kernels/linear/nvfp4/marlin.py @@ -32,10 +32,10 @@ def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: logger.warning_once( - "Your GPU does not have native support for FP4 computation but " - "FP4 quantization is being used. Weight-only FP4 compression " - "will be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads." + "Using Marlin for NVFP4 weight-only GEMM (W4A16). Activations " + "remain fp16/bf16 on this path; W4A4 NVFP4 checkpoints that " + "quantize activations should use a native NVFP4 backend when " + "available." ) prepare_fp4_layer_for_marlin(layer)