[ROCm][Quantization] GPT_OSS in amd-quark format model loading and emulations #29008
Merged: robertgshaw2-redhat merged 105 commits into vllm-project:main from xuebwang-amd:xuebin_add_quark_format_mapping_in_gpt_oss on Feb 10, 2026.
Commits (105):
372cb3e
add quark format moe in WeightsMapper
xuebwang-amd 36657bc
add _load_weights_mxfp4 and refactor get_expert_mapping
xuebwang-amd 900c3ad
add create bias weights
xuebwang-amd 24dd58e
update create weights for moe bias
xuebwang-amd 34517af
adjust w13 & w2 weight shape to fix aiter shuffle weights issue
xuebwang-amd 3b35794
maybe_roundup_hidden_size for amd quark scenarios
xuebwang-amd f7db770
support both fp8 and mxfp4 quark model loading (would probably refac…
xuebwang-amd 939aea6
clean and add TODO
xuebwang-amd 4a1c93a
add moe bias into quant config
xuebwang-amd 4adf440
This is a main commit containing essential changes
xuebwang-amd 4ae0583
Resolved merge conflicts and associated typings
xuebwang-amd 145c462
remove some comments and fix some pre-commit issues
xuebwang-amd c4a05ed
fix a derived problem from issue #30621
xuebwang-amd abd7003
update leveraging latest apply
xuebwang-amd 09ff1b9
remove emulation condition to fix tp4/8 model loading
xuebwang-amd bec9310
unify original mxfp4 model loading and quark model (both mxfp4 and fp…
xuebwang-amd 52f5662
unify OCPMX_W4A16, OCPMX_W4AFP8 into QuarkOCP_MX_MoEMethod
xuebwang-amd 8cceb81
fix some pre-commit issues (to be continued)
xuebwang-amd be16ad6
fix all pre-commit errors at once
xuebwang-amd fc029e0
quark w8a8 fp8 for gpt_oss
xuebwang-amd d28c80d
rename test_gpt_oss_attn_quantization.py to include all tests for gpt…
xuebwang-amd accd286
update test script
xuebwang-amd d686235
fix tiny pre-commit issue in test_gpt_oss.py
xuebwang-amd 8bec3f1
fix fp8 weight loading
xuebwang-amd 8e0b767
a tiny refactor to make input_scale loading compact
xuebwang-amd 8e3bc67
fix a (last) fp8 loading issue
xuebwang-amd 6a3dfcc
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd 85a98ec
fix docs/readthedocs.org error
xuebwang-amd 16fdecb
fix ci issues
xuebwang-amd eb2b024
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd aac4075
undo one fixing associated with code lint
xuebwang-amd 2ae8be3
fix a mxfp4 loading issue
xuebwang-amd 4b5714f
add quant_config variables to avoid NotImplementedError
xuebwang-amd 506cb6c
fix using prepared moe quant config
xuebwang-amd 382c40a
Resolved merge conflict
xuebwang-amd 09c8bc5
tiny update after rebase
xuebwang-amd 6e05a00
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd e23834d
update tests/models/quantization/test_gpt_oss.py
xuebwang-amd 4ae66d1
rename mxfp4_fp8_moe_quant_config to mxfp4_w4a8_moe_quant_config
xuebwang-amd 10c2323
fix a typo
xuebwang-amd 887c716
update using correct/reasonable fixings from cursor
xuebwang-amd 1bece5d
add weight-only mxfp6 by cursor
xuebwang-amd e1e52ea
update using fixings from cursor
xuebwang-amd 04bec4c
update fixing/changes from cursor
xuebwang-amd 1075a62
Resolved merge conflicts
xuebwang-amd 41c90cc
tiny update from cursor side
xuebwang-amd b6dadb0
update using fixings from cursor
xuebwang-amd bde3fd8
refactor: moving ocp_mx based emulation activation qdq out of fused_m…
xuebwang-amd 6e439e7
remove tensor to PrecisionConfig cast
xuebwang-amd 986f331
update according to cursor findings
xuebwang-amd ab5ed9c
tiny update from cursor
xuebwang-amd 480abd7
Resolved merge conflicts
xuebwang-amd 16c6803
one lint fixing
xuebwang-amd e0a0395
Resolved merge conflicts
xuebwang-amd 09cc096
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd d5fa792
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd 905f9ad
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd dfb0867
Resolved merge conflicts
xuebwang-amd 1e72f89
refactor for rounded_hidden_size logics
xuebwang-amd d4b08b1
Resolved merge conflicts
xuebwang-amd 7b24bb7
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd 89810eb
update refactored layer.py
xuebwang-amd 1d589d9
improve fp8 loading for gpt_oss
xuebwang-amd 66f06dc
SIMPLIFY AND REFACTOR loading flow
xuebwang-amd c593eda
update test case
xuebwang-amd 09d8e6f
separate _load_weights_mxfp4 and _load_weights_quark, and fix per_ran…
xuebwang-amd 434096f
Resolved merge conflicts
xuebwang-amd 75fb697
simplify and fixings
xuebwang-amd a6252a3
refactor FusedMoE init and revert gpt_oss_triton_kernels_moe
xuebwang-amd 39dcd24
large refactoring: remove mxfp4_or_bias centered special loading and …
xuebwang-amd 831d2e2
some further refactors
xuebwang-amd 0330e89
Resolved merge conflicts
xuebwang-amd 8b664d3
refactor is_mxfp4_quant as a method of the QuantizationConfig
xuebwang-amd 2e54937
fix a pre-commit error
xuebwang-amd b85d2b9
revert tp_size, tp_rank for others loading
xuebwang-amd 6d19d9c
use narrow_weight naming
xuebwang-amd 813dd82
update _load_weights_mxfp4
xuebwang-amd 671f36c
refactor maybe_roundup_hidden_size call chain
xuebwang-amd 443ab7a
updated
fd26137
revert maybe_roundup_hidden_size
318e8c1
revert breakage
0ff9e99
attempt to simplify weight loading logic
475d30f
attempt to simplify weight loading logic
59de1b6
attempt to simplify weight loading logic
2612a3b
attempt to simplify weight loading logic
dd450d9
fixed
9c4f959
fixed
22beb32
revert the kv cache scale loader into gpt-oss
36306b5
reduce loc changed
ca6af3d
update comment
daa2305
Merge branch 'main' into xuebin_add_quark_format_mapping_in_gpt_oss
robertgshaw2-redhat 548b183
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd 4c4ec3d
keep shard_id=None for mxfp4
xuebwang-amd 53c1020
tiny update
xuebwang-amd 0eda07a
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd af1ae15
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd e8cfdf0
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd 03eb487
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd 0886d8c
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd a5e872a
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd b590751
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd 58e52e1
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd 11eb2fd
fix an error exposed from CI
xuebwang-amd 315a2f0
Merge remote-tracking branch 'origin/main' into xuebin_add_quark_form…
xuebwang-amd c62f664
remove model_type as an attribute of FusedMoE
xuebwang-amd
tests/models/quantization/test_gpt_oss.py (new file, 110 lines):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
End-to-end accuracy test for GPT-OSS model quantization.

Config:
    Task: gsm8k_platinum
    Filter: flexible-extract
    n-shot: 5
    Metric: exact_match

Run: pytest tests/models/quantization/test_gpt_oss.py
"""

import importlib.metadata
import importlib.util
from dataclasses import dataclass

import huggingface_hub
import lm_eval
import pytest
from packaging import version

MODEL_ACCURACIES = {
    # Full quantization: attention linears and MoE linears
    "amd/gpt-oss-20b-WFP8-AFP8-KVFP8": 0.89,
    # MoE linears only quantization
    "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8": 0.89,
    # MoE linears only quantization
    # "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-MXFP4-KV-FP8": 0.90,
}

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
    importlib.metadata.version("amd-quark")
) >= version.parse("0.9.0")


def has_huggingface_access(repo):
    try:
        huggingface_hub.list_repo_refs(repo)
        return True
    except huggingface_hub.errors.RepositoryNotFoundError:
        return False


HF_HUB_AMD_ORG_ACCESS = all(
    has_huggingface_access(model_name) for model_name in MODEL_ACCURACIES
)


@dataclass
class ModelCase:
    model_id: str
    tp: int


@dataclass
class EvaluationConfig:
    model_name: str

    def get_model_args(self, tp_size: int):
        return {
            "pretrained": self.model_name,
            "chat_template_args": {"reasoning_effort": "low"},
            "enable_thinking": True,
            "think_end_token": "200008",
            "tensor_parallel_size": tp_size,
            "dtype": "auto",
            "gpu_memory_utilization": 0.95,
            "trust_remote_code": False,
            "enable_prefix_caching": False,
            "enforce_eager": False,
        }


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("model_name, expected_accuracy", MODEL_ACCURACIES.items())
def test_gpt_oss_attention_quantization(
    model_name: str, tp_size: int, expected_accuracy: float
):
    model_args = EvaluationConfig(model_name).get_model_args(tp_size)

    extra_run_kwargs = {
        "gen_kwargs": {"max_gen_toks": 8000},
        "apply_chat_template": True,
        "fewshot_as_multiturn": True,
        "num_fewshot": 5,
    }

    lm_eval_out = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks="gsm8k_platinum",
        batch_size="auto",
        **extra_run_kwargs,
    )
    measured_accuracy = float(
        lm_eval_out["results"]["gsm8k_platinum"]["exact_match,flexible-extract"]
    )

    rtol = 0.02
    assert (
        measured_accuracy - rtol < expected_accuracy
        and measured_accuracy + rtol > expected_accuracy
    ), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}"
```
tests/models/quantization/test_gpt_oss_attn_quantization.py: deleted (80 lines removed).
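The QUARK_MXFP4_AVAILABLE gate in the new test is a reusable pattern for skipping tests when an optional dependency is missing or below a version floor. A hedged sketch of that check as a helper (the function name is mine; `packaging` stands in below for the `quark` module / `amd-quark` distribution pair, since a module name and its distribution name can differ):

```python
import importlib.metadata
import importlib.util

from packaging import version  # same version-comparison helper the test uses


def dependency_at_least(module: str, dist: str, floor: str) -> bool:
    # find_spec returns None when the module cannot be imported at all
    if importlib.util.find_spec(module) is None:
        return False
    try:
        # Compare the installed distribution's version against the floor
        installed = version.parse(importlib.metadata.version(dist))
    except importlib.metadata.PackageNotFoundError:
        # Module importable but no distribution metadata under that name
        return False
    return installed >= version.parse(floor)


# The test's gate is equivalent to: dependency_at_least("quark", "amd-quark", "0.9.0")
print(dependency_at_least("packaging", "packaging", "1.0"))  # True
```

Feeding the result to `pytest.mark.skipif`, as the test does, keeps the suite green on machines without the optional package.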