
[MoE Refactor] Migrate UnquantizedFusedMoEMethod and oracle to MK flow#36732

Closed

yzong-rh wants to merge 12 commits into vllm-project:main from yzong-rh:yzong-rh/moe-unquantized-refactor-2

Conversation


@yzong-rh yzong-rh commented Mar 11, 2026

Continuation of #36286

Purpose

Migrate the unquantized MoE (BF16) code path from the legacy kernel initialization pattern to the modern modular pattern already used by FP8 and NvFP4.

The CPU backend is not migrated and remains on the old path due to interface differences (see below).

Background

There are two classes of unquantized MoE backends:

  • Monolithic backends (CPU, FlashInfer TRTLLM)
  • Non-monolithic backends (Triton, AITER, FlashInfer CUTLASS, XPU)

In the old path

  • Monolithic backends bypassed the oracle via UNSUPPORTED_BACKEND and were implemented in forward_monolithic_{cuda|cpu}.
  • Non-monolithic backends got a hardcoded NoDPEP prepare/finalize, which may be swapped for an appropriate prepare/finalize in prepare_communication_buffer_for_model (after weight loading and process_weights_after_loading).
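The old-path behavior for non-monolithic backends can be summarized in a minimal sketch. All names here (`NoDPEP`, `DeepEPPrepareFinalize`, `make_kernel_old_path`) are illustrative stand-ins, not vLLM's actual classes; the point is only the two-phase lifecycle, where a placeholder prepare/finalize is hardcoded at weight-loading time and swapped once the communicator is known:

```python
class NoDPEP:
    """Stand-in for the hardcoded no-DP/EP prepare/finalize."""

class DeepEPPrepareFinalize:
    """Stand-in for an all-to-all-capable prepare/finalize."""

def make_kernel_old_path():
    # Old path: prepare/finalize is hardcoded during
    # process_weights_after_loading, before communicators exist.
    return {"prepare_finalize": NoDPEP()}

def prepare_communication_buffer_for_model(kernel, dp_size):
    # Later, once the communicator is set up, the placeholder
    # may be swapped for an appropriate implementation.
    if dp_size > 1:
        kernel["prepare_finalize"] = DeepEPPrepareFinalize()
    return kernel
```

The new path removes the second phase entirely: the oracle picks the right prepare/finalize up front, so the later swap becomes a no-op.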

This PR

Lifecycle: Old Path vs New Path

Non-monolithic (Triton, AITER, FlashInfer CUTLASS, XPU)

OLD                                          NEW
───                                          ───
UnquantizedFusedMoEMethod                    UnquantizedFusedMoEMethod
  ::process_weights_after_loading()            ::process_weights_after_loading()
  └─ UnquantizedFusedMoEMethod                 └─ UnquantizedFusedMoEMethod
       ::_setup_kernel()                            ::_setup_kernel()
       └─ make_unquantized_moe_kernel()             └─ make_unquantized_moe_kernel()
            hardcodes NoDPEP                             get appropriate PrepareAndFinalize
       └─ stored in self.kernel                     └─ stored in self.moe_kernel
       ∴ supports_internal_mk = False               ∴ supports_internal_mk = True

DeviceCommunicatorBase                       DeviceCommunicatorBase
  ::prepare_communication_buffer_for_model()   ::prepare_communication_buffer_for_model()
  └─ FusedMoE::maybe_init_modular_kernel()     └─ FusedMoE::maybe_init_modular_kernel()
       └─ RUNS: may wrap quant method               └─ NO-OP (early return supports_internal_mk = True)
          with FusedMoEModularMethod

FusedMoE::forward()                          FusedMoE::forward()
  └─ MoERunner::forward()                     └─ MoERunner::forward()
       ├─ router selects topK                      ├─ router selects topK
       ├─ if dp>1: runner dispatches/combines      ├─ (kernel handles dispatch internally)
       └─ UnquantizedFusedMoEMethod::apply()       └─ UnquantizedFusedMoEMethod::apply()
            └─ FusedMoEKernel::apply()                  └─ FusedMoEKernel::apply()

Monolithic — GPU (FlashInfer TRTLLM)

OLD                                          NEW
───                                          ───
UnquantizedFusedMoEMethod                    UnquantizedFusedMoEMethod
  ::process_weights_after_loading()            ::process_weights_after_loading()
  └─ UnquantizedFusedMoEMethod                 └─ UnquantizedFusedMoEMethod
       ::_setup_kernel()                            ::_setup_kernel()
       └─ SKIPPED (in UNSUPPORTED_BACKEND)          └─ make_unquantized_moe_kernel()
       └─ _is_monolithic set manually                    builds FusedMoEKernel
       └─ self.moe_kernel = None                         with TrtLlmBf16Experts
                                                    └─ stored in self.moe_kernel
                                                    ∴ supports_internal_mk = True

DeviceCommunicatorBase                       DeviceCommunicatorBase
  ::prepare_communication_buffer_for_model()   ::prepare_communication_buffer_for_model()
  └─ FusedMoE::maybe_init_modular_kernel()     └─ FusedMoE::maybe_init_modular_kernel()
       └─ NO-OP (is_monolithic = True)              └─ NO-OP (is_monolithic = True)

FusedMoE::forward()                          FusedMoE::forward()
  └─ MoERunner::forward()                     └─ MoERunner::forward()
       └─ UnquantizedFusedMoEMethod                └─ UnquantizedFusedMoEMethod
            ::apply_monolithic()                        ::apply_monolithic()
            └─ forward_monolithic_cuda()                └─ FusedMoEKernel
               (hand-written, bypasses oracle)               ::apply_monolithic()

Monolithic — CPU (not migrated, unchanged)

UnquantizedFusedMoEMethod
  ::process_weights_after_loading()
  └─ UnquantizedFusedMoEMethod::_setup_kernel()
       └─ SKIPPED (backend == CPU)
       └─ self.moe_kernel = None, is_monolithic = True (hardcoded)
       └─ self.cpu_fused_moe set up directly

DeviceCommunicatorBase
  ::prepare_communication_buffer_for_model()
  └─ FusedMoE::maybe_init_modular_kernel()
       └─ NO-OP (is_monolithic = True)

FusedMoE::forward()
  └─ MoERunner::forward()
       └─ UnquantizedFusedMoEMethod::apply_monolithic()
            └─ self.cpu_fused_moe(...)  (bypasses oracle)

Changes

New file experts/trtllm_bf16_moe.py: TrtLlmBf16Experts, a FusedMoEExpertsMonolithic subclass wrapping the flashinfer.fused_moe.trtllm_bf16_moe call.

oracle/unquantized.py:

  • select_unquantized_moe_backend now returns (backend, experts_cls) instead of just backend, mirroring FP8. CPU returns (CPU, None).
  • Removed UNSUPPORTED_BACKEND. Added BATCHED_TRITON enum variant and backend_to_kernel_cls mapping.
  • Backend selection uses FP8's priority-list fallback pattern: iterate candidates, call is_supported_config, log and skip unsupported ones.
  • make_unquantized_moe_kernel now calls maybe_make_prepare_finalize(allow_new_interface=True) instead of hardcoding NoDPEP, and always returns a FusedMoEKernel.
  • FlashInfer TRTLLM weight preprocessing (w13 half-swap + block layout) moved into convert_to_unquantized_kernel_format.
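The priority-list fallback pattern borrowed from FP8 can be sketched as follows. The candidate classes and the `is_supported_config` signature here are illustrative stand-ins, not vLLM's actual API; the sketch only shows the selection shape (iterate candidates in preference order, log and skip unsupported ones, return the first match):

```python
import logging

logger = logging.getLogger(__name__)

class TrtLlmExperts:
    @staticmethod
    def is_supported_config(cfg):
        # A kernel might require a newer GPU architecture, for example.
        return (cfg["sm"] >= 100, "requires SM100+")

class TritonExperts:
    @staticmethod
    def is_supported_config(cfg):
        return (True, "")  # broadly supported fallback

def select_backend(candidates, cfg):
    """Return the first (backend, experts_cls) whose config check passes."""
    for backend, experts_cls in candidates:
        supported, reason = experts_cls.is_supported_config(cfg)
        if supported:
            return backend, experts_cls
        logger.debug("Skipping %s: %s", backend, reason)
    raise ValueError("no supported unquantized MoE backend for this config")

CANDIDATES = [("TRTLLM", TrtLlmExperts), ("TRITON", TritonExperts)]
```

Returning `(backend, experts_cls)` pairs rather than a bare enum is what lets `make_unquantized_moe_kernel` build the kernel without re-deriving the experts class from the backend.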

unquantized_fused_moe_method.py:

  • __init__ stores experts_cls from the backend selector. Removed self.kernel, _is_monolithic, and _select_monolithic.
  • _setup_kernel stores the kernel in self.moe_kernel (not self.kernel), making supports_internal_mk=True and causing maybe_init_modular_kernel to no-op.
  • is_monolithic returns True for CPU, delegates to super() otherwise.
  • forward_native and forward_cuda use self.moe_kernel.apply().
  • apply_monolithic dispatches CPU to self.cpu_fused_moe, all others to self.moe_kernel.apply_monolithic().
  • Removed forward_monolithic_cuda, select_gemm_impl, and the FlashInfer TRTLLM branch from process_weights_after_loading.
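The dispatch rules in the bullets above reduce to a small amount of logic. The attribute names mirror the PR description (`moe_kernel`, `cpu_fused_moe`), but the classes themselves are toy stand-ins, not the real vLLM types:

```python
class FakeKernel:
    """Stand-in for a FusedMoEKernel with a monolithic experts impl."""
    is_monolithic = True

    def apply_monolithic(self, x):
        return ("kernel", x)

class UnquantizedMethodSketch:
    def __init__(self, backend, moe_kernel=None, cpu_fused_moe=None):
        self.backend = backend
        self.moe_kernel = moe_kernel
        self.cpu_fused_moe = cpu_fused_moe

    @property
    def is_monolithic(self):
        # CPU stays hardwired monolithic; everything else delegates
        # to the kernel built by the oracle.
        if self.backend == "CPU":
            return True
        return self.moe_kernel.is_monolithic

    def apply_monolithic(self, x):
        # CPU keeps its hand-written path; all other backends go
        # through the modular kernel.
        if self.backend == "CPU":
            return self.cpu_fused_moe(x)
        return self.moe_kernel.apply_monolithic(x)
```

Because the kernel now lives in `moe_kernel`, `supports_internal_mk` is true and `maybe_init_modular_kernel` can early-return without wrapping the quant method.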

Other cleanups:

  • Removed dead rocm_aiter_moe_enabled condition.
  • TPU/OOT backends replaced with NONE (mirrors [Bugfix][TPU] Return a Default fp8 MoE Backend #32908).
  • Added guard against shared_experts passed to FusedMoEExpertsMonolithic.
  • Strengthened FlashInfer backend platform checks.

The CPU backend (CPUFusedMOE/SGLFusedMOE) stays on the old monolithic path because it has three interface differences that make a clean migration non-trivial:

  1. It performs its own routing with parameters not in FusedMoEConfig (renormalize, scoring_func, custom_routing_function).
  2. It selects between three sub-strategies (SGL, Grouped GEMM, Torch fallback) at weight-loading time based on hardware ISA detection.
  3. The Torch fallback stores per-expert closures on the layer object, unlike the standard apply(w1, w2, ...) interface.
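Interface difference 3 can be illustrated with a toy contrast (names and the scalar "weights" are invented for illustration): the migrated path exposes one call over weights indexed by expert id, while the CPU Torch fallback captures one closure per expert on the layer at weight-loading time:

```python
def standard_apply(x, expert_id, w1, w2):
    # Stacked-weight style: w1/w2 are indexed by expert id at call time.
    return w2[expert_id] * (w1[expert_id] * x)

class LayerWithClosures:
    def __init__(self, w1, w2):
        # Torch-fallback style: one closure per expert, weights baked in.
        self.experts = [
            (lambda x, a=a, b=b: b * (a * x)) for a, b in zip(w1, w2)
        ]
```

Both compute the same result, but the closure form does not fit the standard `apply(w1, w2, ...)` signature, which is part of why the CPU path is left on the old flow.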

Test Plan

Integration tests:

moe-refactor/Mixtral-8x7B-BF16-triton.yaml
moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml
moe-refactor/Qwen3-30B-A3B-BF16-triton.yaml
moe-refactor/Qwen3-30B-A3B-BF16-fi-cutlass.yaml

Updated unit tests:

pytest -v -s tests/kernels/moe/test_unquantized_backend_selection.py
pytest -v -s tests/kernels/moe/test_moe.py::test_unquantized_bf16_flashinfer_trtllm_backend

Other unit tests:

pytest -v -s tests/kernels/moe/test_flashinfer.py::test_convert_moe_weights_to_flashinfer_trtllm_block_layout
pytest -v -s tests/kernels/moe/test_moe.py::test_fused_moe
pytest -v -s tests/kernels/moe/test_moe.py::test_naive_block_assignment_moe
pytest -v -s tests/distributed/test_expert_parallel.py -k Mixtral
pytest -v -s tests/distributed/test_eplb_fused_moe_layer.py

Test Result

Config                         Expected Accuracy   Measured Accuracy
Mixtral-8x7B-BF16-triton       0.5800              0.5572
Mixtral-8x7B-BF16-fi-cutlass   0.5800              0.5686
Qwen3-30B-A3B-BF16-triton      0.8800              0.8931
Qwen3-30B-A3B-BF16-fi-cutlass  0.8800              0.8886

All unit tests pass on a B200 machine.

cc @robertgshaw2-redhat @bnellnm


yzong-rh added 11 commits March 10, 2026 21:01
Signed-off-by: Yifan Zong <yzong@redhat.com>
TPU and OOT platforms do not have in-tree unquantized MoE kernels
and rely on OOT plugins to replace UnquantizedFusedMoEMethod via
CustomOp.register_oot. Having dedicated enum values for them was
misleading — they could never produce a working kernel and would
silently fail at runtime with an opaque AssertionError.

Replace them with a single UnquantizedMoeBackend.NONE value
(mirroring Fp8MoeBackend.NONE) and add an assertion in
UnquantizedFusedMoEMethod.__init__ to fail fast if no OOT plugin
is registered. Also fix the if/elif chain in
select_unquantized_moe_backend to prevent accidental overwrites
across platform checks.

…ithic, which do not support them

…d oracle

[MoE] convert_to_unquantized_kernel_format weight source consistency

…rrect Enums

This reverts commit be21911.

1. override and throw in `select_gemm_impl`
2. remove `NONE` backend and throw early on TPU/OOT platforms
3. remove _maybe_swap_to_batched_variant so that AVAILABLE_BACKENDS is set in one location
4. Use use_deepep_ll_kernels instead of use_all2all_kernels
@yzong-rh yzong-rh marked this pull request as draft March 11, 2026 02:37
@yzong-rh yzong-rh changed the title [MoE] Refactor unquantized MoE backend oracle and kernel initialization [MoE Refactor] Migrate UnquantizedFusedMoEMethod and oracle to MK flow Mar 11, 2026

mergify bot commented Mar 11, 2026

⚠️ The sha of the head commit of this PR conflicts with #36286. Mergify cannot evaluate rules on this PR. Once #36286 is merged or closed, Mergify will resume processing this PR. ⚠️


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and well-executed refactoring of the unquantized MoE backend selection and kernel initialization logic. The changes greatly improve modularity and maintainability by replacing complex conditional chains with a priority-based oracle and dedicated expert classes. The new structure is much cleaner and more extensible. I've found one critical typo that would prevent a backend from being selected correctly, which I've commented on. Otherwise, the changes look excellent.

…quantized.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: yzong-rh <yzong@redhat.com>
@yzong-rh
Contributor Author

Further work ended up being done on the original branch.

@yzong-rh yzong-rh closed this Mar 20, 2026
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Mar 20, 2026