Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
105 commits
Select commit Hold shift + click to select a range
372cb3e
add quark format moe in WeightsMapper
xuebwang-amd Nov 19, 2025
36657bc
add _load_weights_mxfp4 and refactor get_expert_mapping
xuebwang-amd Nov 20, 2025
900c3ad
add create bias weights
xuebwang-amd Nov 20, 2025
24dd58e
update create weights for moe bias
xuebwang-amd Nov 26, 2025
34517af
adjust w13 & w2 weight shape to fix aiter shuffle weights issue
xuebwang-amd Nov 26, 2025
3b35794
maybe_roundup_hidden_size for amd quark senarios
xuebwang-amd Dec 9, 2025
f7db770
support both fp8 and mxfp4 quark model loading (would probobaly refac…
xuebwang-amd Dec 11, 2025
939aea6
clean and add TODO
xuebwang-amd Dec 11, 2025
4a1c93a
add moe bias into quant config
xuebwang-amd Dec 12, 2025
4adf440
This is a main commit containing ensential changes
xuebwang-amd Dec 18, 2025
4ae0583
Resolved merge conflicts and associated typings
xuebwang-amd Dec 19, 2025
145c462
remove some comments and fix some pre-commit issues
xuebwang-amd Dec 19, 2025
c4a05ed
fix a derived problem from issue #30621
xuebwang-amd Dec 19, 2025
abd7003
update leveraging latest apply
xuebwang-amd Dec 19, 2025
09ff1b9
remove emulation condition to fix tp4/8 model loading
xuebwang-amd Dec 19, 2025
bec9310
unify original mxfp4 model loading and quark model (both mxfp4 and fp…
xuebwang-amd Dec 22, 2025
52f5662
unify OCPMX_W4A16, OCPMX_W4AFP8 into QuarkOCP_MX_MoEMethod
xuebwang-amd Dec 22, 2025
8cceb81
fix some pre-commit issues (to be continue)
xuebwang-amd Dec 22, 2025
be16ad6
fix all pre-commit errors at once
xuebwang-amd Dec 23, 2025
fc029e0
quark w8a8 fp8 for gpt_oss
xuebwang-amd Dec 23, 2025
d28c80d
rename test_gpt_oss_attn_quantization.py to include all tests for gpt…
xuebwang-amd Dec 23, 2025
accd286
udpate test script
xuebwang-amd Dec 25, 2025
d686235
fix tiny pre-commit issue in test_gpt_oss.py
xuebwang-amd Dec 25, 2025
8bec3f1
fix fp8 weight loading
xuebwang-amd Jan 4, 2026
8e0b767
a tiny refactor to make input_scale loading compact
xuebwang-amd Jan 4, 2026
8e3bc67
fix a (last) fp8 loading issue
xuebwang-amd Jan 5, 2026
6a3dfcc
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Jan 5, 2026
85a98ec
fix docs/readthedocs.org error
xuebwang-amd Jan 6, 2026
16fdecb
fix ci issues
xuebwang-amd Jan 6, 2026
eb2b024
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Jan 6, 2026
aac4075
undo one fixing associated with code lint
xuebwang-amd Jan 6, 2026
2ae8be3
fix a mxfp4 loading issue
xuebwang-amd Jan 6, 2026
4b5714f
add quant_config variables to avoid NotImplementedError
xuebwang-amd Jan 7, 2026
506cb6c
fix using prepared moe quant config
xuebwang-amd Jan 7, 2026
382c40a
Resolved merge conflict
xuebwang-amd Jan 7, 2026
09c8bc5
tiny update after rebase
xuebwang-amd Jan 7, 2026
6e05a00
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Jan 7, 2026
e23834d
update tests/models/quantization/test_gpt_oss.py
xuebwang-amd Jan 9, 2026
4ae66d1
rename mxfp4_fp8_moe_quant_config to mxfp4_w4a8_moe_quant_config
xuebwang-amd Jan 12, 2026
10c2323
fix a typo
xuebwang-amd Jan 12, 2026
887c716
update using correct/reasonable fixings from cursor
xuebwang-amd Jan 12, 2026
1bece5d
add weight-only mxfp6 by cursor
xuebwang-amd Jan 12, 2026
e1e52ea
update using fixings from cursor
xuebwang-amd Jan 12, 2026
04bec4c
update fixing/changes from cursor
xuebwang-amd Jan 12, 2026
1075a62
Resolved merge conflicts
xuebwang-amd Jan 13, 2026
41c90cc
tiny update from cursor side
xuebwang-amd Jan 13, 2026
b6dadb0
update using fixings from cursor
xuebwang-amd Jan 13, 2026
bde3fd8
refactor: moving ocp_mx based emulation activation qdq out of fused_m…
xuebwang-amd Jan 13, 2026
6e439e7
remove tensor to PrecisionConfig cast
xuebwang-amd Jan 13, 2026
986f331
update according to cursor findings
xuebwang-amd Jan 13, 2026
ab5ed9c
tiny update from cursor
xuebwang-amd Jan 13, 2026
480abd7
Resolved merge conflicts
xuebwang-amd Jan 16, 2026
16c6803
one lint fixing
xuebwang-amd Jan 16, 2026
e0a0395
Resolved merge conflicts
xuebwang-amd Jan 16, 2026
09cc096
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Jan 16, 2026
d5fa792
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Jan 16, 2026
905f9ad
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Jan 17, 2026
dfb0867
Resolved merge conflicts
xuebwang-amd Jan 19, 2026
1e72f89
refactor for rounded_hidden_size logics
xuebwang-amd Jan 21, 2026
d4b08b1
Resolved merge conflicts
xuebwang-amd Jan 21, 2026
7b24bb7
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Jan 22, 2026
89810eb
update refactored layer.py
xuebwang-amd Jan 22, 2026
1d589d9
improve fp8 loading for gpt_oss
xuebwang-amd Jan 23, 2026
66f06dc
SIMPLIFY AND REFACTOR loading flow
xuebwang-amd Jan 23, 2026
c593eda
update test case
xuebwang-amd Jan 23, 2026
09d8e6f
seperate _load_weights_mxfp4 and _load_weights_quark, and fix per_ran…
xuebwang-amd Jan 27, 2026
434096f
Resolved merge conflicts
xuebwang-amd Jan 27, 2026
75fb697
simplify and fixings
xuebwang-amd Jan 27, 2026
a6252a3
refactor FusedMoE init and revert gpt_oss_triton_kernels_moe
xuebwang-amd Jan 29, 2026
39dcd24
large refactoring: remove mxfp4_or_bias centered special loading and …
xuebwang-amd Jan 29, 2026
831d2e2
some further refactors
xuebwang-amd Feb 2, 2026
0330e89
Resolved merge conflicts
xuebwang-amd Feb 2, 2026
8b664d3
refactor is_mxfp4_quant as a method of the QuantizationConfig
xuebwang-amd Feb 3, 2026
2e54937
fix a pre-commit error
xuebwang-amd Feb 3, 2026
b85d2b9
revert tp_size, tp_rank for others loading
xuebwang-amd Feb 3, 2026
6d19d9c
use narrow_weight naming
xuebwang-amd Feb 4, 2026
813dd82
update _load_weights_mxfp4
xuebwang-amd Feb 4, 2026
671f36c
refactor maybe_roundup_hidden_size call chain
xuebwang-amd Feb 4, 2026
443ab7a
updated
Feb 5, 2026
fd26137
revert maybe_roundup_hidden_size
Feb 5, 2026
318e8c1
revert breakage
Feb 5, 2026
0ff9e99
attempt to simplify weight loading logic
Feb 5, 2026
475d30f
attempt to simplify weight loading logic
Feb 5, 2026
59de1b6
attempt to simplify weight loading logic
Feb 5, 2026
2612a3b
attempt to simplify weight loading logic
Feb 5, 2026
dd450d9
fixed
Feb 5, 2026
9c4f959
fixed
Feb 5, 2026
22beb32
rever the kv cache scale loader into gpt-oss
Feb 5, 2026
36306b5
reduce loc changed
Feb 5, 2026
ca6af3d
update comment
Feb 5, 2026
daa2305
Merge branch 'main' into xuebin_add_quark_format_mapping_in_gpt_oss
robertgshaw2-redhat Feb 5, 2026
548b183
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 6, 2026
4c4ec3d
keep shard_id=None for mxfp4
xuebwang-amd Feb 6, 2026
53c1020
tiny update
xuebwang-amd Feb 6, 2026
0eda07a
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 6, 2026
af1ae15
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 6, 2026
e8cfdf0
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 6, 2026
03eb487
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 9, 2026
0886d8c
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 9, 2026
a5e872a
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 9, 2026
b590751
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 9, 2026
58e52e1
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 10, 2026
11eb2fd
fix an error exposed from CI
xuebwang-amd Feb 10, 2026
315a2f0
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd Feb 10, 2026
c62f664
remove model_type as an attribute of FusedMoE
xuebwang-amd Feb 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions tests/kernels/moe/test_gpt_oss_triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
triton_kernel_moe_forward,
)
Expand Down Expand Up @@ -298,12 +298,18 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)

quant_config = FusedMoEQuantConfig.make(
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_scale=pc1,
w2_scale=pc2,
)
if a_dtype == "bf16" and w_dtype == "mx4":
quant_config = mxfp4_w4a16_moe_quant_config(
w1_scale=pc1,
w2_scale=pc2,
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
)
else:
raise NotImplementedError(
f"Quantization configuration for activation={a_dtype} and weight={w_dtype} "
f"has not been implemented."
)

out_triton_monolithic = triton_kernel_moe_forward(
hidden_states=x_tri,
Expand Down
110 changes: 110 additions & 0 deletions tests/models/quantization/test_gpt_oss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
End-to-end accuracy test for GPT-OSS model quantization.

Config:
Task: gsm8k_platinum
Filter: flexible-extract
n-shot: 5
Metric: exact_match

Run: pytest tests/models/quantization/test_gpt_oss.py
"""

import importlib
import importlib.metadata
from dataclasses import dataclass

import huggingface_hub
import lm_eval
import pytest
from packaging import version

MODEL_ACCURACIES = {
# Full quantization: attention linears and MoE linears
"amd/gpt-oss-20b-WFP8-AFP8-KVFP8": 0.89,
# MoE linears only quantization
"amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8": 0.89,
# MoE linears only quantization
# "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-MXFP4-KV-FP8": 0.90,
}

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark")
) >= version.parse("0.9.0")


def has_huggingface_access(repo):
    """Return True when *repo* can be listed on the Hugging Face Hub.

    A missing (or otherwise unreadable) repository surfaces as
    ``RepositoryNotFoundError``, which is translated into ``False``;
    any other error propagates to the caller.
    """
    try:
        huggingface_hub.list_repo_refs(repo)
    except huggingface_hub.errors.RepositoryNotFoundError:
        return False
    return True


# Computed once at import time so the skipif decorators below can use it.
# True only if every model repo in MODEL_ACCURACIES is readable.
# Generator (not a list) inside all(): avoids materializing the results and
# short-circuits on the first inaccessible repo (ruff C419).
HF_HUB_AMD_ORG_ACCESS = all(
    has_huggingface_access(model_name) for model_name in MODEL_ACCURACIES
)


@dataclass
class ModelCase:
    """Pairs a model repo id with the tensor-parallel size to run it at.

    NOTE(review): this dataclass is not referenced anywhere in this file;
    confirm whether it is still needed or can be removed.
    """

    # Hugging Face repo id of the model under test.
    model_id: str
    # Tensor-parallel degree to launch vLLM with.
    tp: int


@dataclass
class EvaluationConfig:
    """Builds the vLLM ``model_args`` mapping for one model under evaluation."""

    # Hugging Face repo id of the model to evaluate.
    model_name: str

    def get_model_args(self, tp_size: int):
        """Return the model_args dict passed to ``lm_eval.simple_evaluate``."""
        model_args = {"pretrained": self.model_name}
        model_args.update(
            chat_template_args={"reasoning_effort": "low"},
            enable_thinking=True,
            think_end_token="200008",
            tensor_parallel_size=tp_size,
            dtype="auto",
            gpu_memory_utilization=0.95,
            trust_remote_code=False,
            enable_prefix_caching=False,
            enforce_eager=False,
        )
        return model_args


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("model_name, expected_accuracy", MODEL_ACCURACIES.items())
def test_gpt_oss_attention_quantization(
    model_name: str, tp_size: int, expected_accuracy: float
):
    """Run gsm8k_platinum through lm-eval on vLLM and check accuracy.

    The measured exact-match (flexible-extract) score must fall within an
    absolute tolerance of the expected accuracy in MODEL_ACCURACIES.
    """
    model_args = EvaluationConfig(model_name).get_model_args(tp_size)

    extra_run_kwargs = {
        "gen_kwargs": {"max_gen_toks": 8000},
        "apply_chat_template": True,
        "fewshot_as_multiturn": True,
        "num_fewshot": 5,
    }

    lm_eval_out = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks="gsm8k_platinum",
        batch_size="auto",
        **extra_run_kwargs,
    )
    measured_accuracy = float(
        lm_eval_out["results"]["gsm8k_platinum"]["exact_match,flexible-extract"]
    )

    # This is an *absolute* tolerance (the original code named it `rtol`
    # and spelled the check as a two-sided strict comparison); `abs(...)`
    # expresses the same condition directly.
    atol = 0.02
    assert abs(measured_accuracy - expected_accuracy) < atol, (
        f"Expected: {expected_accuracy} | Measured: {measured_accuracy}"
    )
80 changes: 0 additions & 80 deletions tests/models/quantization/test_gpt_oss_attn_quantization.py

This file was deleted.

36 changes: 36 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,10 @@ def use_mxfp4_w4a4(self) -> bool:
def use_nvfp4_w4a4(self) -> bool:
return self.quant_dtype == "nvfp4"

@property
def use_mxfp4_w4a8(self) -> bool:
    """True when activations quantize to fp8 and weights to mxfp4."""
    act_is_fp8 = self._a1.dtype == "fp8"
    weight_is_mxfp4 = self._w1.dtype == "mxfp4"
    return act_is_fp8 and weight_is_mxfp4

def config_name(self, dtype: torch.dtype) -> str | None:
"""
Return a string used to construct the filename that contains the
Expand Down Expand Up @@ -532,6 +536,8 @@ def fp8_w8a8_moe_quant_config(
w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: list[int] | None = None,
Expand All @@ -549,6 +555,8 @@ def fp8_w8a8_moe_quant_config(
g1_alphas=g1_alphas,
w2_scale=w2_scale,
g2_alphas=g2_alphas,
w1_bias=w1_bias,
w2_bias=w2_bias,
a1_scale=a1_scale,
a1_gscale=a1_gscale,
a2_scale=a2_scale,
Expand All @@ -564,6 +572,8 @@ def int8_w8a8_moe_quant_config(
w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig:
"""
Expand All @@ -575,6 +585,8 @@ def int8_w8a8_moe_quant_config(
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=None,
Expand Down Expand Up @@ -654,6 +666,26 @@ def mxfp4_mxfp8_moe_quant_config(
)


def mxfp4_w4a8_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for fp8 activations and mxfp4 weights.

    w1_scale/w2_scale: per-weight scales (tensor or triton PrecisionConfig).
    a1_scale/a2_scale: optional activation scales for the two MoE GEMMs.
    w1_bias/w2_bias: optional per-expert biases attached to the weight descs.

    NOTE(review): ``block_shape`` is accepted but never forwarded to either
    FusedMoEQuantDesc below — confirm whether it should be wired through or
    dropped from the signature.
    """
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc("fp8", None, a1_scale, None, None, None),
        _a2=FusedMoEQuantDesc("fp8", None, a2_scale, None, None, None),
        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
    )


def ocp_mx_moe_quant_config(
quant_dtype: str,
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
Expand Down Expand Up @@ -691,6 +723,8 @@ def nvfp4_moe_quant_config(
a2_gscale: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for nvfp4 activations and nvfp4 weights.
Expand All @@ -699,6 +733,8 @@ def nvfp4_moe_quant_config(
"nvfp4",
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
g1_alphas=g1_alphas,
Expand Down
Loading