
[MoE Refactor] Mxfp4 oracle rebased#37128

Open
zyongye wants to merge 27 commits intovllm-project:mainfrom
zyongye:mxfp4_oracle_rebased

Conversation

@zyongye
Member

@zyongye zyongye commented Mar 15, 2026

Purpose

Rebased and improved version of #34983.
Ongoing MXFP4 MoE refactor

  • Refactor MXFP4 MoE from a monolithic 1299-line Mxfp4MoEMethod class to the oracle pattern used by FP8 and NvFP4

    • Create oracle/mxfp4.py with backend selection, weight conversion, quant config, and kernel assembly
    • Create TrtLlmMxfp4ExpertsMonolithic wrapping trtllm_fp4_block_scale_moe() (both BF16 and MXFP8 input modes)
    • Create OAITritonMxfp4ExpertsMonolithic wrapping triton_kernel_moe_forward()
    • Implement _supports_* methods on TrtLlmGenExperts and BaseOAITritonExperts so they work with the oracle
    • Consolidate monolithic/modular into single backend enum values following FP8 pattern (backend_to_kernel_cls returns
      [Monolithic, Modular] list, is_supported_config selects naturally)
    • Move MXFP4-specific hidden_size rounding from layer.py into oracle
    • Simplify mxfp4.py from 1299 to ~430 lines — all inline kernel code removed
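The consolidated backend/kernel selection described in the bullets above can be sketched roughly as follows. All names here are illustrative stand-ins, not vLLM's actual API: one backend enum value maps to a `[Monolithic, Modular]` list, and `is_supported_config` on each class picks the variant naturally.

```python
from enum import Enum


class Mxfp4MoeBackend(Enum):
    FLASHINFER_TRTLLM_MXFP4_BF16 = "flashinfer_trtllm_mxfp4_bf16"


class TrtLlmMonolithic:
    is_monolithic = True

    @classmethod
    def is_supported_config(cls, use_ep: bool) -> bool:
        # the monolithic kernel only applies without expert parallelism
        return not use_ep


class TrtLlmModular:
    is_monolithic = False

    @classmethod
    def is_supported_config(cls, use_ep: bool) -> bool:
        return True


def backend_to_kernel_cls(backend: Mxfp4MoeBackend) -> list[type]:
    # one enum value covers both variants: [Monolithic, Modular]
    return {
        Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16: [TrtLlmMonolithic, TrtLlmModular],
    }[backend]


def select_kernel(backend: Mxfp4MoeBackend, use_ep: bool) -> type:
    # the first class whose is_supported_config() passes is selected
    for cls in backend_to_kernel_cls(backend):
        if cls.is_supported_config(use_ep):
            return cls
    raise ValueError(f"no kernel class supports backend {backend}")
```

Under this scheme, EP deployments naturally fall through to the modular class while non-EP deployments get the monolithic one, with no `_MONOLITHIC` enum duplicates.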

Test plan

GPQA eval on gpt-oss-120b (target ~65.3%):

Backend                                              default   TP=2      DP=2 EP
FLASHINFER TRTLLM MXFP4+BF16 (monolithic/modular)    65.53%    65.53%    66.35%
FLASHINFER TRTLLM MXFP4+MXFP8 (monolithic/modular)   64.39%    65.21%    66.09%
FLASHINFER CUTLASS MXFP4+MXFP8 (modular)             65.66%    64.39%    fail on main
TRITON (monolithic/modular)                          65.66%    66.41%    see note

Note: the Triton kernel for DP=2 EP will hit this error, but both models generate reasonable outputs.

DP/EP for CUTLASS MXFP8 is failing due to a kernel issue; since it also fails on main, we can defer the fix to later PRs.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist gemini-code-assist bot left a comment (Contributor)

Code Review

This pull request is a significant and well-executed refactoring of the MXFP4 MoE implementation. Moving the backend-specific logic into an oracle pattern greatly improves modularity and maintainability, as evidenced by the substantial reduction in code complexity within mxfp4.py. The new structure is much cleaner and easier to follow. I have one critical comment regarding proper object initialization in a base class, which could lead to subtle bugs.

Comment on lines +34 to +40
def __init__(
    self,
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
):
    self.moe_config = moe_config
    self.quant_config = quant_config

critical

The __init__ method of TrtLlmMxfp4ExpertsBase should call super().__init__ to ensure proper initialization of all base classes in its subclasses that use multiple inheritance, such as TrtLlmMxfp4ExpertsMonolithic. Currently, mk.FusedMoEExperts.__init__ is not being called due to the Method Resolution Order (MRO), which can lead to incomplete object initialization and potential bugs if the base class implementation changes. The explicit assignments to self.moe_config and self.quant_config can then be removed as the superclass __init__ already handles them.

Suggested change

Before:

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        self.moe_config = moe_config
        self.quant_config = quant_config

After:

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)
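A minimal, generic illustration of the cooperative-initialization issue this review comment describes (the classes below are stand-ins, not vLLM's actual types): when a base class skips `super().__init__()`, a sibling base in a multiple-inheritance chain is silently never initialized.

```python
class Experts:
    """Stand-in for a kernel base class such as mk.FusedMoEExperts."""

    def __init__(self, moe_config=None, quant_config=None):
        self.moe_config = moe_config
        self.quant_config = quant_config
        self.experts_ready = True  # state that subclasses rely on


class BadBase:
    def __init__(self, moe_config, quant_config):
        # No super().__init__(): the MRO chain stops here, so
        # Experts.__init__ never runs in the subclass below.
        self.moe_config = moe_config
        self.quant_config = quant_config


class GoodBase:
    def __init__(self, moe_config, quant_config):
        # Cooperative init: the next class in the MRO gets called.
        super().__init__(moe_config, quant_config)


class BadMonolithic(BadBase, Experts):
    pass


class GoodMonolithic(GoodBase, Experts):
    pass


assert not hasattr(BadMonolithic("m", "q"), "experts_ready")  # Experts skipped
assert GoodMonolithic("m", "q").experts_ready                 # chain reached Experts
```

The bug is subtle because `BadMonolithic` still works until something in `Experts.__init__` becomes load-bearing, which is exactly the risk the reviewer flags.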

@mergify

mergify bot commented Mar 16, 2026

Documentation preview: https://vllm--37128.org.readthedocs.build/en/37128/

@mergify mergify bot added the documentation Improvements or additions to documentation label Mar 16, 2026
zyongye and others added 6 commits March 16, 2026 02:43
…Expert classes

Refactor MXFP4 MoE from a monolithic 1299-line class to the oracle
pattern used by FP8 and NvFP4. This separates backend selection,
weight conversion, and kernel assembly into oracle/mxfp4.py, and
wraps monolithic kernel calls into proper FusedMoEExpertsMonolithic
classes.

Key changes:
- New oracle/mxfp4.py: Mxfp4MoeBackend enum, select_mxfp4_moe_backend(),
  convert_to_mxfp4_moe_kernel_format(), make_mxfp4_moe_quant_config(),
  make_mxfp4_moe_kernel(), mxfp4_round_up_hidden_size_and_intermediate_size()
- New experts/trtllm_mxfp4_moe.py: TrtLlmMxfp4ExpertsMonolithic wrapping
  trtllm_fp4_block_scale_moe() for both BF16 and MXFP8 input modes
- New OAITritonMxfp4ExpertsMonolithic in gpt_oss_triton_kernels_moe.py
  wrapping triton_kernel_moe_forward()
- Implemented _supports_* on TrtLlmGenExperts and BaseOAITritonExperts
- Removed max_capture_size param from TrtLlmGenExperts (computed internally)
- Moved MXFP4-specific hidden_size rounding from layer.py to oracle
- Simplified mxfp4.py to thin orchestrator (1299 -> ~456 lines)
- Updated quark_moe.py imports to use oracle

GPQA eval results on gpt-oss-120b (target ~65.3%):
- MXFP4+BF16 monolithic (default): 65.53%
- MXFP4+BF16 monolithic (TP=2):    65.53%
- MXFP4+MXFP8 monolithic (default): 64.39%
- MXFP4+MXFP8 monolithic (TP=2):    65.21%
- CUTLASS MXFP4+MXFP8 modular (default): 65.66%
- CUTLASS MXFP4+MXFP8 modular (TP=2):    64.39%
- DP=2 EP (all backends): fail on main (pre-existing assertion)
- Triton monolithic: fail on main (pre-existing shape assertion on SM100)

CK (ROCm) backend refactor deferred to follow-up PR.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Follow the FP8 oracle pattern: use a single backend enum value
(e.g. FLASHINFER_TRTLLM_MXFP4_BF16) for both monolithic and modular
expert classes. backend_to_kernel_cls() returns a list
[Monolithic, Modular] and _return_or_raise() iterates it, letting
is_supported_config() naturally select the right class based on the
deployment configuration (EP → modular, no EP → monolithic).

Removes: FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC,
FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC, TRITON_MONOLITHIC enums
and MONOLITHIC_BACKENDS tuple. is_monolithic is now determined by
self.moe_mk.is_monolithic (from the selected expert class).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…m_mxfp4_moe.py

P0 fixes:
- Remove dead unreachable code in convert_to_mxfp4_moe_kernel_format
- Remove unused layer.gemm1_alpha/beta/clamp_limit from weight conversion
  (expert classes create their own copies)
- Remove dead _can_support_mxfp4 function and unused imports from
  mxfp4_utils.py

P1 fixes:
- Use public quant_config.quant_dtype instead of private _a1.dtype
- Add comment explaining in-place FusedMoEConfig mutation for dim rounding

Move TrtLlmGenExperts from trtllm_moe.py into experts/trtllm_mxfp4_moe.py
as TrtLlmMxfp4ExpertsModular (with backward-compatible TrtLlmGenExperts
alias). Shared base class provides _supports_*, gemm1_alpha/beta/clamp_limit,
and max_capture_size for both modular and monolithic variants. Delete
trtllm_moe.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
- Issue 7: Use use_batched_activation_format property instead of
  use_deepep_ll_kernels for activation format selection (consistent
  with FP8 oracle)
- Issue 8: Fix inconsistent log message quoting for backend names
- Issue 10: Add moe_backend user override support via map_mxfp4_backend()
  (matches FP8/NvFP4 oracle pattern)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…sues

- Add kMxfp4Static to MarlinExperts._supports_quant_scheme() so the
  MXFP4 oracle can select Marlin for MXFP4 quantized models
- Fix CompressedTensorsW4A4Mxfp4MoEMethod to use MXFP4 oracle
  (Mxfp4MoeBackend, make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config)
  instead of incorrectly routing through the NvFP4 oracle
- Derive weight dimensions from tensor shapes in
  prepare_moe_fp4_layer_for_marlin() to handle padded/rounded sizes
- Rename self.moe_mk -> self.moe_kernel for consistency with
  FusedMoEMethodBase convention
- Remove dead make_mxfp4_moe_quant_config from oracle/nvfp4.py
- Update test_mxfp4_triton_ep.py for new oracle API
- Update docs to reference new trtllm_mxfp4_moe.py module path

Verified on H200 (GPQA on gpt-oss-120b, target ~65.3%):
  Marlin  tp=1: 66.29%  tp=2: 65.78%  dp=2+ep: fail (pre-existing)
  Triton  tp=1: 64.46%  tp=2: 66.48%  dp=2+ep: fail (pre-existing)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@zyongye zyongye force-pushed the mxfp4_oracle_rebased branch from 3231979 to a7cd3b5 Compare March 16, 2026 02:43
zyongye added 4 commits March 17, 2026 06:01
…FP8-style for-loop selection

- Revert test_mxfp4_triton_ep.py to main branch's TestTritonMoeForwardExpertMap
  (real EP test), drop TestMxfp4TritonIsMonolithic (dead oracle APIs)
- Add _get_priority_backends and _backend_activation_key helpers for
  MXFP8 activation key mapping (kMxfp8Dynamic for W4A8 backends)
- Replace if/elif platform chain with FP8-style for-loop selection
  that delegates to is_supported_config on each kernel class
- Priority order: TRTLLM BF16 > Triton > CUTLASS BF16 > Triton Unfused > Marlin
- Add temporary test_mxfp4_oracle_selection.py to verify selection logic

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
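The FP8-style for-loop selection this commit describes can be sketched as follows. Backend names and the support predicates are purely illustrative; the real code delegates to `is_supported_config` on each kernel class rather than to lambdas.

```python
# Walk backends in priority order and take the first one whose support
# check accepts the current deployment configuration.
PRIORITY = ["trtllm_bf16", "triton", "cutlass_bf16", "triton_unfused", "marlin"]

SUPPORT = {
    # backend -> predicate over (sm_version, has_triton_kernels); toy stand-ins
    "trtllm_bf16": lambda sm, tk: sm >= 100,
    "triton": lambda sm, tk: tk,
    "cutlass_bf16": lambda sm, tk: sm >= 100,
    "triton_unfused": lambda sm, tk: tk,
    "marlin": lambda sm, tk: sm >= 80,
}


def select_backend(sm_version: int, has_triton_kernels: bool) -> str:
    for backend in PRIORITY:
        if SUPPORT[backend](sm_version, has_triton_kernels):
            return backend
    raise RuntimeError("no MXFP4 MoE backend supports this configuration")
```

Compared to an if/elif platform chain, adding a backend only means registering a class and a position in the priority list.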
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@zyongye zyongye marked this pull request as ready for review March 17, 2026 16:41
@zyongye zyongye changed the title Mxfp4 oracle rebased [MoE Refactor] Mxfp4 oracle rebased Mar 17, 2026
@zyongye
Member Author

zyongye commented Mar 17, 2026

/gemini review

Wire the CK/AITER MXFP4 MoE backend through the oracle selection
framework, restoring ROCm GFX950 support that was dropped during the
oracle refactor.

Oracle changes (oracle/mxfp4.py):
- Add CK enum, backend_to_kernel_cls mapping, user-facing "ck" alias
- Add CK to priority list (TRTLLM > CK > Triton)
- Add CK weight conversion: de-interleave w13 rows, view as
  float4_e2m1fn_x2, AITER shuffle_weight/scale_a16w4, permute bias
- Add CK to mxfp4_w4a16 quant config path

AiterExperts changes (rocm_aiter_fused_moe.py):
- Add kMxfp4Static to _supports_quant_scheme (resolves TODO)
- Add SWIGLUOAI to _supports_activation
- Handle SWIGLUOAI in rocm_aiter_fused_experts via AITER swiglu type
- Pass bias1/bias2 from quant_config for MXFP4 path
- Use BLOCK_1X32 quant_method for both w4a4 and w4a16 MXFP4 paths

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@mergify mergify bot added the rocm Related to AMD ROCm label Mar 18, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 18, 2026
@zyongye zyongye added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
zyongye added 2 commits March 18, 2026 04:16
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@mergify

mergify bot commented Mar 18, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zyongye.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 18, 2026
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@mergify mergify bot removed the needs-rebase label Mar 18, 2026
zyongye added 8 commits March 18, 2026 20:28
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1 requires Blackwell (SM100+),
so only set it for gpt-oss-20b when running on Blackwell devices.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
is_device_capability(90) checks for exactly (9,0), but ROCm gfx942
maps to (9,4) and gfx950 to (9,5), causing all MXFP4 MoE backends to
be rejected. Use range check (9,0) <= cap < (11,0) to match the
oracle's own triton_kernels_supported assessment.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
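The fix this commit describes amounts to replacing an equality check with a tuple range check; a minimal sketch (the helper name is hypothetical):

```python
# Tuple comparison makes the range check natural: ROCm gfx942 maps to
# capability (9, 4) and gfx950 to (9, 5), both inside [(9, 0), (11, 0)),
# whereas an exact `cap == (9, 0)` test rejects them.
def capability_in_range(cap: tuple[int, int]) -> bool:
    return (9, 0) <= cap < (11, 0)


assert capability_in_range((9, 0))   # SM90 (Hopper)
assert capability_in_range((9, 4))   # gfx942
assert capability_in_range((9, 5))   # gfx950
assert capability_in_range((10, 0))  # SM100 (Blackwell)
assert not capability_in_range((8, 0))
```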
The rocm_aiter_fused_moe module cannot be imported in non-ROCm doc
build environments, causing mkdocs strict mode to fail on broken
cross-references for AiterExperts and rocm_aiter_fused_experts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>

@yzong-rh yzong-rh left a comment


Thanks for the contribution! LGTM if @mgoin or @robertgshaw2-redhat are ok as well.

Maybe add gpt-oss to tests/evals/gsm8k/configs/moe-refactor? Tests already exist in tests/evals/gpt_oss/configs


NIT: Remove those now that they are unused?

Or even better, remove this function and call maybe_roundup_layer_hidden_size directly.

Member Author

I will keep this and just delete is_mxfp4_quant to preserve the interface, since maybe_roundup_layer_hidden_size comes from prepare/finalize; the best approach is to move it into the prepare_finalize itself.

elif activation == MoEActivation.GELU:
    activation_method = ActivationMethod.GELU
elif activation == MoEActivation.SWIGLUOAI:
    activation_method = rocm_aiter_ops.get_aiter_activation_type("swiglu")


NIT: Maybe add to class ActivationMethod(IntEnum)?

  self.mxfp4_backend: Mxfp4MoeBackend | None = None
  if self.ocp_mx_scheme == "w_mxfp4":
-     self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
+     self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe)


Note to self: QuarkOCP_MX_MoEMethod doesn't use FusedMoEKernel yet.

Member Author

I think it's worth another PR, since there is a different mapping in the other file:

vllm/vllm/_aiter_ops.py

Lines 1083 to 1090 in b55156e

mapping = {
    "none": ActivationType.No,
    "no": ActivationType.No,
    "silu": ActivationType.Silu,
    "gelu": ActivationType.Gelu,
    "swiglu": ActivationType.Swiglu,
}
return mapping.get(name)

logger.info_once("Using Marlin backend for mxfp4 lora")
return Mxfp4MoeBackend.MARLIN, backend_to_kernel_cls(Mxfp4MoeBackend.MARLIN)[0]

# Issue 7 fix: use high-level property (consistent with FP8 oracle)


NIT: Remove comment?

def _make_log_backend(backend: Mxfp4MoeBackend):
    return f"Using '{backend.value}' Mxfp4 MoE backend."

# Issue 8 fix: consistent quoting in log messages


NIT: Remove comment?

    Select the primary MXFP4 MoE backend.
    Note: Shape-specific fallbacks may still occur at runtime.
    """
    triton_kernels_supported = has_triton_kernels() and (


Ideally, we should rely on FusedMoEExperts::is_supported_config and not duplicate triton kernel support here.

Member Author

I will keep this the same for now just to be consistent with fp8.

if current_platform.is_xpu():
    backend = Mxfp4MoeBackend.XPU
    logger.info_once(_make_log_backend(backend))
    return backend, None


Note to self: TODO: Create XPU monolithic expert.

zyongye added 2 commits March 19, 2026 23:30
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@mgoin mgoin left a comment (Member)

LGTM overall, just comments on a few things that seem to be missing

Comment on lines +246 to +248
# MXFP4-specific rounding is now handled in the MXFP4 oracle
# (mxfp4_round_up_hidden_size_and_intermediate_size) called from
# Mxfp4MoEMethod.__init__().
Member

nit: cruft

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Mar 20, 2026
@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 20, 2026
@mgoin mgoin enabled auto-merge (squash) March 20, 2026 21:31

Labels

documentation (Improvements or additions to documentation), gpt-oss (Related to GPT-OSS models), nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)
