Merged

50 commits
2349099
mxfp6 support
fxmarty-amd Jul 10, 2025
d538659
refactor mxfp4 to accomodate mxfp6
fxmarty-amd Jul 10, 2025
49b2bbe
add todo
fxmarty-amd Jul 10, 2025
a016c18
rename mxfp4_utils to ocp_mx_utils and add fp6 dequant function
fxmarty-amd Jul 11, 2025
a5835ac
use str instead of enum as torch.library infer_schema does not suppor…
fxmarty-amd Jul 11, 2025
08da3e0
fix a few remaining bugs
fxmarty-amd Jul 11, 2025
43f0ae8
simulate on mi350 as well for now
fxmarty-amd Jul 11, 2025
c537c76
fix e2m3/e3m2 bug
fxmarty-amd Jul 15, 2025
13dea3e
Merge branch 'main' into mxfp6_mixed
fxmarty-amd Jul 22, 2025
a37ef27
wip update tests
fxmarty-amd Jul 22, 2025
de19714
Merge branch 'main' into mxfp6_mixed
fxmarty-amd Jul 22, 2025
b106df5
Merge branch 'main' into mxfp6_mixed
fxmarty-amd Jul 23, 2025
f07c3a8
update tests
fxmarty-amd Jul 23, 2025
b9f9124
update documentation
fxmarty-amd Jul 23, 2025
9c1a90f
address review comments
fxmarty-amd Jul 23, 2025
e4aa06e
linting
fxmarty-amd Jul 23, 2025
1e53ab9
linting 2
fxmarty-amd Jul 23, 2025
5bfa8cb
linting 3
fxmarty-amd Jul 23, 2025
33e431f
linting 4... if only mypy would run locally
fxmarty-amd Jul 23, 2025
bccb5e3
Merge branch 'main' into mxfp6_mixed_updated
fxmarty-amd Aug 6, 2025
bad17cc
undo current_platform.supports_mx() change, moved to standalone #22355
fxmarty-amd Aug 6, 2025
7bbfbc7
Merge branch 'main' into mxfp6_mixed
fxmarty-amd Sep 4, 2025
ef895ca
post merge fixes
fxmarty-amd Sep 4, 2025
160d5c3
fix issues
fxmarty-amd Sep 4, 2025
e53c5c7
edit reference
fxmarty-amd Sep 4, 2025
df3d964
linting
fxmarty-amd Sep 4, 2025
5024d70
linting
fxmarty-amd Sep 4, 2025
28473a8
address comments
fxmarty-amd Sep 6, 2025
4291d3a
fix mxfp4/fp4 typos
fxmarty-amd Sep 6, 2025
4ed70f6
Merge branch 'main' into mxfp6_mixed
fxmarty-amd Sep 24, 2025
37326f0
post-merge cleanup
fxmarty-amd Sep 24, 2025
d2ef885
linting & fixes
fxmarty-amd Sep 24, 2025
3b7260f
cleanup
fxmarty-amd Sep 24, 2025
bdb3706
disable check_model test as it is not working well with v1
fxmarty-amd Sep 24, 2025
f374514
linting
fxmarty-amd Sep 24, 2025
868d5a9
linting
fxmarty-amd Sep 24, 2025
28aa39c
address review comments
fxmarty-amd Sep 25, 2025
b957668
Merge branch 'main' into mxfp6_mixed
fxmarty-amd Oct 2, 2025
397722d
post merge fixes
fxmarty-amd Oct 2, 2025
cebca37
reset files
fxmarty-amd Oct 6, 2025
8ef7e00
lint
fxmarty-amd Oct 6, 2025
996ddd9
Merge branch 'main' into mxfp6_mixed
fxmarty-amd Oct 6, 2025
76e6ee7
prefix with 'mx' everywhere as suggested
fxmarty-amd Oct 6, 2025
28b995b
fix remaining issues
fxmarty-amd Oct 6, 2025
dc246fa
linting
fxmarty-amd Oct 6, 2025
5cd1b2e
typo
fxmarty-amd Oct 6, 2025
1288977
fix typo
fxmarty-amd Oct 6, 2025
ec51387
fix tests
fxmarty-amd Oct 6, 2025
eb5f8f6
linting
fxmarty-amd Oct 6, 2025
fb0d1ac
skip test if amd-quark is not installed
fxmarty-amd Oct 7, 2025
12 changes: 8 additions & 4 deletions docs/features/quantization/quark.md
@@ -231,27 +231,31 @@ python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \
--tasks gsm8k
```

## Using MXFP4 models
## Using OCP MX (MXFP4, MXFP6) models

vLLM supports loading MXFP4 models quantized offline through AMD Quark, compliant with [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
vLLM supports loading MXFP4 and MXFP6 models quantized offline through AMD Quark, compliant with the [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).

The scheme currently only supports dynamic quantization for activations.

Example usage, after installing the latest AMD Quark release:

```bash
vllm serve fxmarty/qwen_1.5-moe-a2.7b-mxfp4 --tensor-parallel-size 1
# or, for a model using fp6 activations and fp4 weights:
vllm serve fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3 --tensor-parallel-size 1
```
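The dynamic activation quantization mentioned above needs no calibration step: each 32-element block gets a power-of-two (E8M0) scale computed on the fly from its own maximum magnitude. The scale selection can be sketched as follows — a simplified reading of the OCP MX specification, where `mx_block_scale` and `elem_emax` are illustrative names, not vLLM or Quark APIs, and the exact rounding used by the fused kernels may differ:

```python
import math

def mx_block_scale(block, elem_emax=2):
    # Shared power-of-two (E8M0) scale for one 32-element block, chosen so
    # the block's largest magnitude lands in the element format's top binade.
    # elem_emax is 2 for FP4 (E2M1) and FP6 (E2M3), and 4 for FP6 (E3M2).
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    return 2.0 ** (math.floor(math.log2(amax)) - elem_emax)

block = [0.01 * i for i in range(32)]  # toy activation block
scale = mx_block_scale(block)          # computed per block at runtime
scaled = [v / scale for v in block]    # these then get rounded to FP4/FP6 codes
```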

A simulation of the matrix multiplication execution in MXFP4 can be run on devices that do not support MXFP4 operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from MXFP4 to half precision on the fly, using a fused kernel. This is useful e.g. to evaluate MXFP4 models using vLLM, or alternatively to benefit from the ~4x memory savings (compared to float16 and bfloat16).
A simulation of the matrix multiplication execution in MXFP4/MXFP6 can be run on devices that do not support OCP MX operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from FP4/FP6 to half precision on the fly using a fused kernel. This is useful, for example, to evaluate MXFP4/MXFP6 models with vLLM, or to benefit from the ~2.5-4x memory savings compared to float16 and bfloat16.
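The ~2.5-4x figure follows from the per-block scale overhead: each element stores 4 or 6 bits, plus an 8-bit E8M0 scale shared by 32 elements. A quick back-of-the-envelope check:

```python
def effective_bits(elem_bits, block_size=32, scale_bits=8):
    # Bits per element, including the amortized shared block scale.
    return elem_bits + scale_bits / block_size

print(16 / effective_bits(4))  # MXFP4 vs. bfloat16: ~3.76x
print(16 / effective_bits(6))  # MXFP6 vs. bfloat16: 2.56x
```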

To generate offline models quantized using the MXFP4 data type, the easiest approach is to use AMD Quark's [quantization script](https://quark.docs.amd.com/latest/pytorch/example_quark_torch_llm_ptq.html), for example:

```bash
python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
--quant_scheme w_mxfp4_a_mxfp4_sym \
--quant_scheme w_mxfp4_a_mxfp4 \
--output_dir qwen_1.5-moe-a2.7b-mxfp4 \
--skip_evaluation \
--model_export hf_format \
--group_size 32
```

The current integration supports [all combinations of FP4, FP6_E3M2 and FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) for either weights or activations. Some target hardware, such as AMD Instinct MI350/MI355, supports mixed-precision GEMM, for example using FP6 for activations and FP4 for weights.
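For reference, decoding one FP6 E2M3 element (1 sign bit, 2 exponent bits with bias 1, 3 mantissa bits) can be sketched as below. This is only an illustration of the number format: the actual vLLM dequantization kernel is fused and operates on packed tensors, and `decode_fp6_e2m3` is a hypothetical helper, not a vLLM function.

```python
def decode_fp6_e2m3(bits: int) -> float:
    # 6-bit code layout: [sign | 2-bit exponent (bias 1) | 3-bit mantissa]
    sign = -1.0 if (bits >> 5) & 1 else 1.0
    exp = (bits >> 3) & 0b11
    mantissa = bits & 0b111
    if exp == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias = 0.
        return sign * (mantissa / 8.0)
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exp - 1)

# A dequantized weight is then scale * element, with the shared block
# scale decoded from E8M0 as 2 ** (e8m0_bits - 127).
print(decode_fp6_e2m3(0b011111))  # 7.5, the largest E2M3 magnitude
```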
@@ -10,13 +10,6 @@
import torch
from packaging import version

from vllm.model_executor.layers.quantization.quark.quark import (
QuarkLinearMethod,
QuarkW4A4MXFP4,
)
from vllm.model_executor.layers.quantization.quark.quark_moe import (
QuarkW4A4MXFp4MoEMethod,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer

@@ -63,9 +56,11 @@ def enable_pickle(monkeypatch):
@pytest.mark.parametrize(
"model_case",
[
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=2),
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=1),
ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=4),
],
)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@@ -76,22 +71,33 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
f"{torch.cuda.device_count()}"
)

# `cuda_graph_sizes=[16]` to reduce load time.
with vllm_runner(
model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy"
model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy",
cuda_graph_sizes=[16],
) as llm:
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
# def check_model(model):
# from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
# QuarkLinearMethod)
# from vllm.model_executor.layers.quantization.quark.schemes.quark_ocp_mx import QuarkOCP_MX # noqa: E501
# from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
# QuarkOCP_MX_MoEMethod)

def check_model(model):
layer = model.model.layers[0]
# layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj
# qkv_proj = layer.self_attn.qkv_proj

assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
# assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
# assert isinstance(qkv_proj.scheme, QuarkOCP_MX)

assert isinstance(layer.mlp.experts.quant_method, QuarkW4A4MXFp4MoEMethod)
# assert isinstance(layer.mlp.experts.quant_method,
# QuarkOCP_MX_MoEMethod)

if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
llm.apply_model(check_model)
# if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
# llm.apply_model(check_model)

output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
assert output
93 changes: 74 additions & 19 deletions tests/quantization/test_quark.py
@@ -11,6 +11,7 @@
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Optional

import huggingface_hub
import lm_eval
@@ -148,39 +149,93 @@ def get_state_dict(model):


@dataclass
class ModelCase:
model_id: str
tp: int


@dataclass
class GSM8KAccuracyTestConfig:
class AccuracyTestConfig:
model_name: str
excepted_value: float

def get_model_args(self) -> str:
return (
f"pretrained={self.model_name},"
"dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768"
)


ACCURACY_CONFIGS = [
def get_model_args(
self,
tp_size: int,
model_max_len: Optional[int] = None,
kwargs: Optional[dict] = None,
) -> dict:
if kwargs is None:
kwargs = {}

model_args = {
"pretrained": self.model_name,
"dtype": "auto",
"add_bos_token": True,
"tensor_parallel_size": tp_size,
"gpu_memory_utilization": 0.7,
**kwargs,
}
if model_max_len is not None:
model_args["max_model_len"] = model_max_len

return model_args


GSM8K_ACCURACY_CONFIGS = [
# Private model.
GSM8KAccuracyTestConfig(
AccuracyTestConfig(
model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
excepted_value=0.96,
),
]

WIKITEXT_ACCURACY_CONFIGS = [
AccuracyTestConfig(
model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3",
excepted_value=11.3,
),
AccuracyTestConfig(
model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2",
excepted_value=10.6,
),
AccuracyTestConfig(
model_name="fxmarty/qwen_1.5-moe-a2.7b-mxfp4", excepted_value=12.4
),
]


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
if torch.cuda.device_count() < tp_size:
pytest.skip(
f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
)

task = "wikitext"
rtol = 0.1

# Smaller cuda_graph_sizes to speed up the test.
results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(
tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
),
tasks=task,
batch_size=64,
)

EXPECTED_VALUE = config.excepted_value
measured_value = results["results"][task]["word_perplexity,none"]
assert (
measured_value < EXPECTED_VALUE + rtol
and measured_value > EXPECTED_VALUE - rtol
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
not HF_HUB_AMD_ORG_ACCESS,
reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
if torch.cuda.device_count() < 8:
pytest.skip(
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
@@ -193,7 +248,7 @@ def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):

results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(),
model_args=config.get_model_args(tp_size=8, model_max_len=38768),
tasks=task,
batch_size=64,
num_fewshot=8,
73 changes: 56 additions & 17 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -9,6 +9,10 @@
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_DTYPES,
OCP_MX_Scheme,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.utils import cdiv, has_triton_kernels
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@@ -30,7 +34,7 @@ def _get_config_dtype_str(
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
ocp_mx_scheme: Optional[str] = None,
) -> Optional[str]:
"""
Return a string used to construct the filename that contains the
@@ -43,8 +47,11 @@
return "int8_w8a16"
elif use_int4_w4a16:
return "int4_w4a16"
elif use_mxfp4_w4a4:
return "mxfp4_w4a4"
elif ocp_mx_scheme is not None:
# The output of this function is passed to `try_get_optimal_moe_config`,
# and as we only simulate OCP MX execution in fused_moe for now,
# we will NOT look for `*,dtype=w_mxfp4_a_mxfp4.json` for now.
return None
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
@@ -289,8 +296,23 @@ def use_int4_w4a16(self) -> bool:
return self._a1.dtype is None and self._w1.dtype == "int4"

@property
def use_mxfp4_w4a4(self) -> bool:
return self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4"
def ocp_mx_scheme(self) -> Union[str, None]:
if not hasattr(self, "_ocp_mx_scheme"):
if (self._a1.dtype is not None and not isinstance(self._a1.dtype, str)) or (
self._w1.dtype is not None and not isinstance(self._w1.dtype, str)
):
self._ocp_mx_scheme = None
else:
ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
self._a1.dtype, self._w1.dtype
)

if ocp_mx_scheme is not None:
ocp_mx_scheme = ocp_mx_scheme.value

self._ocp_mx_scheme = ocp_mx_scheme

return self._ocp_mx_scheme

@property
def use_mxfp4_w4a16(self) -> bool:
@@ -310,7 +332,7 @@ def config_name(self, dtype: torch.dtype) -> Optional[str]:
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
ocp_mx_scheme=self.ocp_mx_scheme,
dtype=dtype,
)

@@ -371,12 +393,14 @@ def make(
w2_bias: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
weight_dtype: Union[torch.dtype, str, None] = None,
Contributor Author:

FusedMoEParallelConfig currently assumes a common dtype for weights/activations, quant_dtype. I added this weight_dtype to hopefully not break anything, but it is not clean.

Maybe quant_dtype is too vague.

Collaborator:

I think this is fine. Each FusedMoEQuantDesc has its own dtype. quant_dtype is meant for activations. The weights can have their own types, which don't need to be the same.

) -> "FusedMoEQuantConfig":
"""
General builder function for a FusedMoEQuantConfig.
- quant_dtype: Optional quantization type. None if activations are
unquantized or quantized prior to calling. Note: "nvfp4" and
"mxfp4" are the only valid string values for quant_dtype.
unquantized or quantized prior to calling. Note: "nvfp4", "mxfp4",
"mxfp6_e3m2", "mxfp6_e2m3" are the only valid string values
for quant_dtype.
- per_act_token_quant: Activations have per token quantization.
- per_out_ch_quant: Outputs have per channel quantization. (only
for cutlass).
@@ -395,22 +419,33 @@ def make(
- w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization.
"""
assert (
not isinstance(quant_dtype, str)
or quant_dtype == "nvfp4"
or quant_dtype == "mxfp4"
)
assert not isinstance(quant_dtype, str) or quant_dtype in {
"nvfp4",
"mxfp4",
"mxfp6_e3m2",
"mxfp6_e2m3",
}
assert not isinstance(weight_dtype, str) or weight_dtype in {
"nvfp4",
"mxfp4",
"mxfp6_e3m2",
"mxfp6_e2m3",
}

if weight_dtype is None:
weight_dtype = quant_dtype

a_shape, w_shape = _quant_flags_to_group_shape(
quant_dtype, per_act_token_quant, per_out_ch_quant, block_shape
)
quant_config = FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
_a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
_w1=FusedMoEQuantDesc(
quant_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias
weight_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias
),
_w2=FusedMoEQuantDesc(
quant_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
)
assert quant_config.per_act_token_quant == per_act_token_quant
@@ -482,9 +517,11 @@ def mxfp4_w4a16_moe_quant_config(
)


def mxfp4_w4a4_moe_quant_config(
def ocp_mx_moe_quant_config(
quant_dtype: str,
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
weight_dtype: Optional[str] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
@@ -494,8 +531,10 @@ def mxfp4_w4a4_moe_quant_config(
"""
Construct a quant config for OCP MX (MXFP4/MXFP6) activations and weights.
"""
assert quant_dtype in OCP_MX_DTYPES
return FusedMoEQuantConfig.make(
"mxfp4",
quant_dtype=quant_dtype,
weight_dtype=weight_dtype,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,