
[ROCm][Quantization] GPT_OSS in amd-quark format model loading and emulations #29008

Merged
robertgshaw2-redhat merged 105 commits into vllm-project:main from xuebwang-amd:xuebin_add_quark_format_mapping_in_gpt_oss
Feb 10, 2026
Conversation

@xuebwang-amd
Contributor

@xuebwang-amd xuebwang-amd commented Nov 19, 2025

Purpose

This PR aims for:

  • Quark model loading, unified with the existing mxfp4 loading path for the original openai/gpt-oss-20b and openai/gpt-oss-120b checkpoints
  • OCPMX_W4A16 and OCPMX_W4AFP8 MoE schemes with an emulation forward pass, unified into the class QuarkOCP_MX_MoEMethod
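For context on the emulation path: MXFP4 stores 4-bit e2m1 values in blocks of 32 elements that share a single e8m0 exponent scale. A pure-Python sketch of the per-block dequantization (code-point table per the OCP Microscaling spec; the function name is illustrative, not vLLM's API):

```python
# e2m1 (MXFP4) code points: 1 sign bit + 2 exponent bits + 1 mantissa bit.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
               -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequant_mxfp4_block(codes, e8m0_scale):
    """Dequantize one MXFP4 block: 4-bit codes sharing an e8m0 scale (bias 127)."""
    scale = 2.0 ** (e8m0_scale - 127)
    return [E2M1_VALUES[c] * scale for c in codes]
```

For example, dequant_mxfp4_block([2, 7], 128) yields [2.0, 12.0]: codes 2 and 7 decode to 1.0 and 6.0, and the shared scale is 2^(128-127) = 2.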

Test Plan

  • Models:
    • GPT_OSS_20B
    • GPT_OSS_120B
  • Quantization schemes:
    • W: MXFP4, A: BF16, (optional) KV: FP8
    • W: MXFP4, A: FP8, (optional) KV: FP8
    • W: MXFP4, A: MXFP4, (optional) KV: FP8
  • TP:
    • TP1
    • TP2
    • TP4
    • TP8
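The A: FP8 rows above are exercised through quantize-dequantize (QDQ) emulation: activations are rounded to the nearest FP8-representable value and immediately converted back, reproducing FP8 numerics on hardware without a native FP8 MoE kernel. A self-contained scalar sketch for the e4m3 format (illustrative only; vLLM's emulation operates on tensors):

```python
import math

def qdq_e4m3(x: float, scale: float) -> float:
    """Quantize-dequantize through FP8 e4m3 (4 exponent bits, bias 7,
    3 mantissa bits, max magnitude 448) to emulate FP8 activation error."""
    y = max(-448.0, min(448.0, x / scale))  # clamp to the e4m3fn range
    if y == 0.0:
        return 0.0
    sign = -1.0 if y < 0 else 1.0
    m, e = math.frexp(abs(y))               # abs(y) = m * 2**e, m in [0.5, 1)
    if e - 1 >= -6:                         # normal: spacing depends on binade
        step = 2.0 ** (e - 1 - 3)
    else:                                   # subnormal: fixed spacing 2**-9
        step = 2.0 ** (-6 - 3)
    return sign * round(abs(y) / step) * step * scale
```

For instance, qdq_e4m3(1.3, 1.0) returns 1.25 (the nearest e4m3 value), and inputs beyond the representable range saturate at 448.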

See results below.

Test Result

(Image: lm_eval accuracy results table; see the PR page.)

(Sub)-tasks

TODO

  • (near-term) Loading without specifying the model type.
  • (mid-term) GPT-OSS quantized in auto-mixed precision; this will probably require refactoring the current loading (and possibly inference) methodologies.

Note

Extends GPT‑OSS quantization and MoE execution across OpenAI and amd‑quark formats with unified loaders, new quant configs, and runtime updates.

  • Unifies GPT‑OSS MoE weight loading (mxfp4 and quark): fused experts/bias handling, EP/TP‑aware slicing, expert mapping, and KV‑cache scale loading; enables bias in qkv_proj/o_proj
  • Adds mxfp4_w4a16_moe_quant_config and mxfp4_w4a8_moe_quant_config; allows bias in fp8_w8a8_moe_quant_config, int8_w8a8_moe_quant_config, nvfp4_moe_quant_config; introduces use_mxfp4_w4a8; expands OCP_MX_Scheme (e.g., w_mxfp4, *_a_fp8)
  • Generalizes fused MoE runtime: input quant dtype detection, MXFP4/MXFP6 dequant paths, and FP8 activation QDQ emulation for *_a_fp8; GPT‑OSS Triton fused MoE supports MXFP4 W4A16 with PrecisionConfig
  • Adjusts hidden‑size roundup for model_type==gpt_oss with MXFP4; enhances expert weight loader for GPT‑OSS fused/bias cases; adds kv_cache_scale_loader
  • Tests: adds Triton‑kernel equivalence test for GPT‑OSS MoE (MXFP4) and an E2E GSM8k accuracy test; removes old attention‑only test

Written by Cursor Bugbot for commit 04bec4c. This will update automatically on new commits.
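The hidden-size roundup mentioned in the summary can be illustrated with a minimal sketch (the multiple-of-256 constant and the simplified signature are assumptions for illustration; vLLM's maybe_roundup_hidden_size takes more parameters):

```python
def maybe_roundup_hidden_size(hidden_size: int, model_type: str,
                              is_mxfp4: bool, multiple: int = 256):
    """Pad the hidden size so MXFP4 GPT-OSS weights meet kernel alignment.

    Returns the (possibly padded) size and whether padding was applied.
    """
    if model_type == "gpt_oss" and is_mxfp4:
        padded = -(-hidden_size // multiple) * multiple  # ceil-divide, then scale
        return padded, padded != hidden_size
    return hidden_size, False
```

Under these assumptions, a hidden size of 2880 would round up to 3072 for gpt_oss with MXFP4, and stay untouched for other model types.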

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@mergify mergify bot added the gpt-oss (Related to GPT-OSS models) and rocm (Related to AMD ROCm) labels Nov 19, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds weight mappings for the gpt-oss model to support the quark quantization format. The changes are in vllm/model_executor/models/gpt_oss.py.

My review identifies a critical issue in the new mappings. They are missing the .experts submodule path, which will likely cause weight loading to fail. I've provided a suggestion to correct this. This issue might also be present in the existing MoE weight mappings in the same file, which you may want to investigate as well.

Comment on lines +674 to +681
".gate_up_proj.weight": ".w13_weight",
".gate_up_proj.weight_scale": ".w13_weight_scale",
".gate_up_proj.bias": ".w13_bias",
".gate_up_proj.input_scale": ".w13_input_scale",
".down_proj.weight": ".w2_weight",
".down_proj.weight_scale": ".w2_weight_scale",
".down_proj.bias": ".w2_bias",
".down_proj.input_scale": ".w2_input_scale"

critical

The weight mappings for the MoE layers appear to be missing the .experts submodule in the target path. The MoE parameters are located within the experts submodule of the MLPBlock, so the vLLM parameter names will be of the form ...mlp.experts.w13_weight, etc. The current mappings would incorrectly resolve to ...mlp.w13_weight, which would cause weight loading to fail.

To ensure the weights are loaded correctly, the .experts part should be included in the mapping.

Suggested change
".gate_up_proj.weight": ".w13_weight",
".gate_up_proj.weight_scale": ".w13_weight_scale",
".gate_up_proj.bias": ".w13_bias",
".gate_up_proj.input_scale": ".w13_input_scale",
".down_proj.weight": ".w2_weight",
".down_proj.weight_scale": ".w2_weight_scale",
".down_proj.bias": ".w2_bias",
".down_proj.input_scale": ".w2_input_scale"
".gate_up_proj.weight": ".experts.w13_weight",
".gate_up_proj.weight_scale": ".experts.w13_weight_scale",
".gate_up_proj.bias": ".experts.w13_bias",
".gate_up_proj.input_scale": ".experts.w13_input_scale",
".down_proj.weight": ".experts.w2_weight",
".down_proj.weight_scale": ".experts.w2_weight_scale",
".down_proj.bias": ".experts.w2_bias",
".down_proj.input_scale": ".experts.w2_input_scale"
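To see why the missing .experts segment matters, a simplified suffix-rewrite sketch of how such a mapping might be consumed (hypothetical helper; vLLM's actual loader logic differs):

```python
# Hypothetical, trimmed-down version of the checkpoint-name mapping above.
QUARK_MOE_MAPPING = {
    ".gate_up_proj.weight": ".experts.w13_weight",
    ".down_proj.weight": ".experts.w2_weight",
}

def remap_param_name(ckpt_name: str, mapping: dict) -> str:
    """Rewrite a checkpoint tensor name to the vLLM parameter name."""
    for src_suffix, dst_suffix in mapping.items():
        if ckpt_name.endswith(src_suffix):
            return ckpt_name[: -len(src_suffix)] + dst_suffix
    return ckpt_name
```

With the .experts segment, "model.layers.0.mlp.gate_up_proj.weight" resolves to "model.layers.0.mlp.experts.w13_weight", matching the module path; without it, the name would resolve to ...mlp.w13_weight and loading would fail.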


@tjtanaa
Collaborator

tjtanaa commented Nov 19, 2025

@xuebwang-amd are there model weights available to test this feature? And can you share an lm_eval score so we know it runs? I have been seeing many quark patches for the same GPT-OSS model.

@xuebwang-amd xuebwang-amd marked this pull request as draft November 19, 2025 15:41
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
…tor later)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@mergify

mergify bot commented Dec 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xuebwang-amd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 11, 2025
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@xuebwang-amd xuebwang-amd marked this pull request as ready for review December 18, 2025 13:10

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

    "weight_scale_2" in weight_name
    if uses_weight_scale_2
    else "weight_scale" in weight_name
) or "input_scale" in weight_name
if is_per_tensor:
    self._load_per_tensor_weight_scale(
        shard_id=shard_id,
        param=param,
        loaded_weight=loaded_weight,
        expert_id=expert_id,
    )

P1: Supply combined_w13 argument when loading per-tensor scales

The _load_per_tensor_weight_scale signature now requires combined_w13, but the ModelOpt per-tensor path still calls it without that argument. Hitting this branch will raise TypeError: _load_per_tensor_weight_scale() missing 1 required positional argument: 'combined_w13', preventing ModelOpt MoE checkpoints from loading.
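The failure mode described here is plain Python behavior: adding a required positional parameter breaks any call site still written against the old signature. A minimal standalone reproduction (the function body is a stand-in; only the signature mirrors the PR):

```python
def _load_per_tensor_weight_scale(shard_id, param, loaded_weight, expert_id, combined_w13):
    """Stand-in for the updated method; only the signature matters here."""
    return shard_id, combined_w13

# A caller still using the old four-argument form fails immediately:
message = ""
try:
    _load_per_tensor_weight_scale(shard_id="w13", param=None, loaded_weight=None, expert_id=0)
except TypeError as exc:
    message = str(exc)  # "... missing 1 required positional argument: 'combined_w13'"
```

The fix is simply to thread the new argument through every caller of the helper.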


Comment on lines +562 to +576
# Round up hidden size if needed.
hidden_size, is_rounded_hidden_size = maybe_roundup_hidden_size(
    hidden_size,
    moe_in_dtype,
    self.moe_parallel_config,
    self.model_type,
    self.is_mxfp4_quant,
    self.emulate_quant,
    is_lora_enabled=self.vllm_config.lora_config is not None,
)
print(f"is_rounded_hidden_size is {is_rounded_hidden_size}")

if is_rounded_hidden_size:
    self.hidden_size = hidden_size
self.moe_config: FusedMoEConfig = FusedMoEConfig(


P1: Honor hidden_size padding for non-gpt_oss MoE

The padding result from maybe_roundup_hidden_size is only applied when is_rounded_hidden_size is true, yet that flag is set only for gpt_oss+mxfp4 in the helper (lines 250‑269). When other backends such as DeepEP round the hidden size inside maybe_roundup_layer_hidden_size, is_rounded_hidden_size stays false, so self.hidden_size/self.moe_config remain at the unpadded size while weights are built with the padded hidden_size later in __init__. DeepEP runs will then have layer metadata smaller than the actual weight shapes, leading to buffer shape mismatches at runtime.


@xuebwang-amd
Contributor Author

@xuebwang-amd are there model weights available to test this feature? And can you share an lm_eval score so we know it runs? I have been seeing many quark patches for the same GPT-OSS model.

Thanks @tjtanaa. Yes, there have been several recent GPT-OSS efforts from the amd-quark side. They are:

  • PR#27334: for GPT-OSS attention quantization
  • PR#27980: for GPT-OSS kv cache quantization (it will be absorbed into this PR)
  • PR#28638: for quark config postprocessed by apply_vllm_mapper
  • PR#29008 (current one):
    • quark model loading
    • OCPMX_W4A16, OCPMX_W4AFP8 MoE scheme and emulation forward

Please see more info, including lm_eval accuracy results, at the top of the PR description.

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
@xuebwang-amd
Copy link
Contributor Author

Hi @robertgshaw2-redhat, following up on our discussion, the requested changes have now been implemented. Could you please take another look to confirm this is ready to go? Thank you!

@github-project-automation github-project-automation bot moved this from In progress to Ready in gpt-oss Issues & Enhancements Feb 10, 2026
@robertgshaw2-redhat robertgshaw2-redhat merged commit b129136 into vllm-project:main Feb 10, 2026
68 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 10, 2026
gshtras pushed a commit to ROCm/vllm that referenced this pull request Feb 10, 2026
…ulations (vllm-project#29008)

Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Comment on lines +251 to +252
if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
    pass  # No QDQ needed for these schemes
Contributor


@xuebwang-amd this looks unnecessary. quant_dtype should already be properly set.

Comment on lines +706 to +709
self._emulate = (
    not current_platform.supports_mx()
    or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
Contributor


@xuebwang-amd this is not correct. w_mxfp4_a_mxfp6 models cannot run through the aiter backend.


Labels

  • gpt-oss: Related to GPT-OSS models
  • ready: ONLY add when PR is ready to merge/full CI is needed
  • rocm: Related to AMD ROCm

Projects

Status: Done


5 participants