
[MoE Refactor] Migrate Unquantized to Full Oracle Flow#36286

Open
yzong-rh wants to merge 25 commits into vllm-project:main from
yzong-rh:yzong-rh/moe-unquantized-refactor

Conversation

@yzong-rh commented Mar 6, 2026

Purpose

Migrate the unquantized MoE (BF16) code path from the legacy kernel initialization pattern to the modern modular pattern already used by FP8 and NvFP4.

The CPU backend is not migrated and remains on the old path due to interface differences (see below).

Background

There are two kinds of unquantized MoE backends:

  • Monolithic backends (CPU, FlashInfer TRTLLM)
  • Non-monolithic backends (Triton, AITER, FlashInfer CUTLASS, XPU)

In the old path

  • Monolithic backends bypassed the oracle via UNSUPPORTED_BACKEND and were implemented in forward_monolithic_{cuda|cpu}.
  • Non-monolithic backends got a hardcoded NoDPEP prepare/finalize, which may be swapped for an appropriate prepare/finalize in prepare_communication_buffer_for_model (after weight loading and process_weights_after_loading).
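The prepare/finalize swap described above can be sketched as follows. All class and function names here are illustrative stand-ins, not vLLM's actual API:

```python
class NoDPEPPrepareFinalize:
    """Single-rank prepare/finalize: no dispatch/combine across ranks."""


class All2AllPrepareFinalize:
    """Prepare/finalize that dispatches and combines tokens across ranks."""


def old_path_prepare_finalize(use_all2all_kernels: bool):
    # Old path: NoDPEP is hardcoded when the kernel is built; a later
    # hook (prepare_communication_buffer_for_model) may swap it out.
    return NoDPEPPrepareFinalize()


def new_path_prepare_finalize(use_all2all_kernels: bool):
    # New path: the appropriate prepare/finalize is chosen up front,
    # so no later swap is needed.
    if use_all2all_kernels:
        return All2AllPrepareFinalize()
    return NoDPEPPrepareFinalize()
```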

This PR

Lifecycle: Old Path vs New Path

Non-monolithic (Triton, AITER, FlashInfer CUTLASS, XPU)

OLD                                          NEW
───                                          ───
UnquantizedFusedMoEMethod                    UnquantizedFusedMoEMethod
  ::process_weights_after_loading()            ::process_weights_after_loading()
  └─ UnquantizedFusedMoEMethod                 └─ UnquantizedFusedMoEMethod
       ::_setup_kernel()                            ::_setup_kernel()
       └─ make_unquantized_moe_kernel()             └─ make_unquantized_moe_kernel()
            hardcodes NoDPEP                             get appropriate PrepareAndFinalize
       └─ stored in self.kernel                     └─ stored in self.moe_kernel
       ∴ supports_internal_mk = False               ∴ supports_internal_mk = True

DeviceCommunicatorBase                       DeviceCommunicatorBase
  ::prepare_communication_buffer_for_model()   ::prepare_communication_buffer_for_model()
  └─ FusedMoE::maybe_init_modular_kernel()     └─ FusedMoE::maybe_init_modular_kernel()
       └─ RUNS: may wrap quant method               └─ NO-OP (early return supports_internal_mk = True)
          with FusedMoEModularMethod

FusedMoE::forward()                          FusedMoE::forward()
  └─ MoERunner::forward()                     └─ MoERunner::forward()
       ├─ router selects topK                      ├─ router selects topK
       ├─ if dp>1: runner dispatches/combines      ├─ (kernel handles dispatch internally)
       └─ UnquantizedFusedMoEMethod::apply()       └─ UnquantizedFusedMoEMethod::apply()
            └─ FusedMoEKernel::apply()                  └─ FusedMoEKernel::apply()

Monolithic — GPU (FlashInfer TRTLLM)

OLD                                          NEW
───                                          ───
UnquantizedFusedMoEMethod                    UnquantizedFusedMoEMethod
  ::process_weights_after_loading()            ::process_weights_after_loading()
  └─ UnquantizedFusedMoEMethod                 └─ UnquantizedFusedMoEMethod
       ::_setup_kernel()                            ::_setup_kernel()
       └─ SKIPPED (in UNSUPPORTED_BACKEND)          └─ make_unquantized_moe_kernel()
       └─ _is_monolithic set manually                    builds FusedMoEKernel
       └─ self.moe_kernel = None                         with TrtLlmBf16Experts
                                                    └─ stored in self.moe_kernel
                                                    ∴ supports_internal_mk = True

DeviceCommunicatorBase                       DeviceCommunicatorBase
  ::prepare_communication_buffer_for_model()   ::prepare_communication_buffer_for_model()
  └─ FusedMoE::maybe_init_modular_kernel()     └─ FusedMoE::maybe_init_modular_kernel()
       └─ NO-OP (is_monolithic = True)              └─ NO-OP (is_monolithic = True)

FusedMoE::forward()                          FusedMoE::forward()
  └─ MoERunner::forward()                     └─ MoERunner::forward()
       └─ UnquantizedFusedMoEMethod                └─ UnquantizedFusedMoEMethod
            ::apply_monolithic()                        ::apply_monolithic()
            └─ forward_monolithic_cuda()                └─ FusedMoEKernel
               (hand-written, bypasses oracle)               ::apply_monolithic()

Monolithic — CPU (not migrated, unchanged)

UnquantizedFusedMoEMethod
  ::process_weights_after_loading()
  └─ UnquantizedFusedMoEMethod::_setup_kernel()
       └─ SKIPPED (backend == CPU)
       └─ self.moe_kernel = None, is_monolithic = True (hardcoded)
       └─ self.cpu_fused_moe set up directly

DeviceCommunicatorBase
  ::prepare_communication_buffer_for_model()
  └─ FusedMoE::maybe_init_modular_kernel()
       └─ NO-OP (is_monolithic = True)

FusedMoE::forward()
  └─ MoERunner::forward()
       └─ UnquantizedFusedMoEMethod::apply_monolithic()
            └─ self.cpu_fused_moe(...)  (bypasses oracle)

Changes

New file experts/trtllm_bf16_moe.py adds TrtLlmBf16Experts, a FusedMoEExpertsMonolithic subclass wrapping the flashinfer.fused_moe.trtllm_bf16_moe call.
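The shape of such a monolithic experts class can be sketched as below. This is a toy stand-in: the real TrtLlmBf16Experts signature and the flashinfer entry point differ.

```python
class MonolithicExpertsSketch:
    """Stand-in for a FusedMoEExpertsMonolithic subclass: routing,
    dispatch, and the expert GEMMs all happen inside one fused call,
    so only apply_monolithic() is exposed."""

    is_monolithic = True

    def __init__(self, fused_fn):
        # fused_fn stands in for an entry point such as
        # flashinfer.fused_moe.trtllm_bf16_moe (signature simplified here).
        self.fused_fn = fused_fn

    def apply_monolithic(self, hidden_states, router_logits, w13, w2):
        return self.fused_fn(hidden_states, router_logits, w13, w2)
```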

oracle/unquantized.py:

  • select_unquantized_moe_backend now returns (backend, experts_cls) instead of just backend, mirroring FP8. CPU returns (CPU, None).
  • Removed UNSUPPORTED_BACKEND. Added BATCHED_TRITON enum variant and backend_to_kernel_cls mapping.
  • Backend selection uses FP8's priority-list fallback pattern: iterate candidates, call is_supported_config, log and skip unsupported ones.
  • make_unquantized_moe_kernel now calls maybe_make_prepare_finalize(allow_new_interface=True) instead of hardcoding NoDPEP, and always returns a FusedMoEKernel.
  • FlashInfer TRTLLM weight preprocessing (w13 half-swap + block layout) moved into convert_to_unquantized_kernel_format.
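The priority-list fallback pattern borrowed from FP8 can be sketched as follows. The enum members and the `is_supported_config` callback here are illustrative placeholders, not the real oracle's API:

```python
import logging
from enum import Enum

logger = logging.getLogger(__name__)


class Backend(Enum):
    # Illustrative subset; not the real UnquantizedMoeBackend members.
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
    FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
    TRITON = "TRITON"


def select_backend(is_supported_config) -> Backend:
    """FP8-style priority-list fallback: walk the candidates in order,
    skip (and log) any whose config check fails, take the first hit."""
    for backend in (
        Backend.FLASHINFER_TRTLLM,
        Backend.FLASHINFER_CUTLASS,
        Backend.TRITON,
    ):
        if is_supported_config(backend):
            return backend
        logger.debug("%s unsupported for this config, skipping", backend.value)
    raise ValueError("no unquantized MoE backend supports this config")
```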

unquantized_fused_moe_method.py:

  • __init__ stores experts_cls from the backend selector. Removed self.kernel, _is_monolithic, and _select_monolithic.
  • _setup_kernel stores the kernel in self.moe_kernel (not self.kernel), making supports_internal_mk=True and causing maybe_init_modular_kernel to no-op.
  • is_monolithic returns True for CPU, delegates to super() otherwise.
  • forward_native and forward_cuda use self.moe_kernel.apply().
  • apply_monolithic dispatches CPU to self.cpu_fused_moe, all others to self.moe_kernel.apply_monolithic().
  • Removed forward_monolithic_cuda, select_gemm_impl, and the FlashInfer TRTLLM branch from process_weights_after_loading.
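The CPU-versus-kernel dispatch described in these bullets can be sketched as a toy model; names and the backend-as-string representation are illustrative, not vLLM's actual code:

```python
class UnquantizedDispatchSketch:
    """Toy dispatch mirroring the bullets above; names are illustrative."""

    def __init__(self, backend, moe_kernel=None, cpu_fused_moe=None):
        self.backend = backend
        self.moe_kernel = moe_kernel
        self.cpu_fused_moe = cpu_fused_moe

    @property
    def is_monolithic(self) -> bool:
        # CPU stays hardcoded monolithic; everything else asks the kernel.
        if self.backend == "CPU":
            return True
        return getattr(self.moe_kernel, "is_monolithic", False)

    def apply_monolithic(self, x):
        if self.backend == "CPU":
            return self.cpu_fused_moe(x)
        return self.moe_kernel.apply_monolithic(x)
```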

Other cleanups:

  • Removed dead rocm_aiter_moe_enabled condition.
  • TPU/OOT backends replaced with NONE (mirrors [Bugfix][TPU] Return a Default fp8 MoE Backend #32908).
  • Added guard against shared_experts passed to FusedMoEExpertsMonolithic.
  • Strengthened FlashInfer backend platform checks.

The CPU backend (CPUFusedMOE/SGLFusedMOE) stays on the old monolithic path because it has three interface differences that make a clean migration non-trivial:

  1. It performs its own routing with parameters not in FusedMoEConfig (renormalize, scoring_func, custom_routing_function).
  2. It selects between three sub-strategies (SGL, Grouped GEMM, Torch fallback) at weight-loading time based on hardware ISA detection.
  3. The Torch fallback stores per-expert closures on the layer object, unlike the standard apply(w1, w2, ...) interface.
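Point 3 is the structural mismatch that blocks a clean migration; a minimal illustration (scalar weights for brevity, not vLLM's actual interfaces) of the two styles:

```python
def standard_apply(x, w1, w2):
    # Standard experts interface: weights are passed in explicitly,
    # so a single function serves every expert.
    return (x * w1) * w2


class TorchFallbackSketch:
    """Stand-in for the CPU Torch fallback: one closure per expert,
    each capturing that expert's weights (scalars here for brevity)."""

    def __init__(self, expert_weights):
        self.expert_fns = [
            (lambda x, a=w1, b=w2: (x * a) * b)
            for (w1, w2) in expert_weights
        ]
```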

Test Plan

Integration tests:

moe-refactor/Mixtral-8x7B-BF16-triton.yaml
moe-refactor/Mixtral-8x7B-BF16-fi-cutlass.yaml
moe-refactor/Qwen3-30B-A3B-BF16-triton.yaml
moe-refactor/Qwen3-30B-A3B-BF16-fi-cutlass.yaml

Updated unit tests:

pytest -v -s tests/kernels/moe/test_unquantized_backend_selection.py
pytest -v -s tests/kernels/moe/test_moe.py::test_unquantized_bf16_flashinfer_trtllm_backend

Other unit tests:

pytest -v -s tests/kernels/moe/test_flashinfer.py::test_convert_moe_weights_to_flashinfer_trtllm_block_layout
pytest -v -s tests/kernels/moe/test_moe.py::test_fused_moe
pytest -v -s tests/kernels/moe/test_moe.py::test_naive_block_assignment_moe
pytest -v -s tests/distributed/test_expert_parallel.py -k Mixtral
pytest -v -s tests/distributed/test_eplb_fused_moe_layer.py

Test Result

Config                           Expected Accuracy   Measured Accuracy
Mixtral-8x7B-BF16-triton         0.5800              0.5572
Mixtral-8x7B-BF16-fi-cutlass     0.5800              0.5686
Qwen3-30B-A3B-BF16-triton        0.8800              0.8931
Qwen3-30B-A3B-BF16-fi-cutlass    0.8800              0.8886

All unit tests pass on a B200 machine.

cc @robertgshaw2-redhat @bnellnm



@yzong-rh yzong-rh changed the title [MoE Refactor] Migrate UnquantizedFusedMoEMethod and oracle to MK flow#1 [MoE Refactor] Migrate UnquantizedFusedMoEMethod and oracle to MK flow Mar 6, 2026
@mergify mergify bot added the nvidia label Mar 6, 2026
@gemini-code-assist bot left a comment:

Code Review

This pull request is a significant and well-executed refactoring of the Mixture of Experts (MoE) infrastructure, migrating the unquantized MoE method and its selection oracle to the new modular kernel (MK) flow. This greatly improves the modularity, maintainability, and extensibility of the MoE implementation. Key changes include the introduction of TrtLlmBf16Experts as a new modular kernel, a completely rewritten and more robust backend selection oracle for unquantized MoE, and strengthened platform support checks for various FlashInfer backends. The code is well-structured and successfully moves towards a more unified MoE framework. I've identified one area for improvement to ensure the correctness of the new backend's support check.

("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"),
# TODO: Add check for `trtllm_bf16_moe`?
gemini-code-assist bot (severity: high):

The TODO comment highlights a missing check for trtllm_bf16_moe. The has_flashinfer_trtllm_fused_moe function is used by TrtLlmBf16Experts to determine if the kernel is supported. Without this check, the system might incorrectly report support for the bf16 TRT-LLM kernel, potentially leading to a runtime error if the trtllm_bf16_moe function is missing from the flashinfer library. Please add the check to ensure correctness.

Suggested change
# TODO: Add check for `trtllm_bf16_moe`?
("flashinfer.fused_moe", "trtllm_bf16_moe"),

@yzong-rh (Author) replied Mar 6, 2026:
TODO is intentional. I'd like to get some eyes on this before adding it or removing the TODO.

github-actions bot commented Mar 6, 2026:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@yzong-rh yzong-rh marked this pull request as ready for review March 7, 2026 01:29
w2_weight,
)

return w13_weight, w2_weight
@yzong-rh (Author) commented:

Not sure if we need to ensure these weights are contiguous()

Comment on lines +566 to +568
if (
moe_config.moe_parallel_config.use_all2all_kernels and not is_monolithic
)
Collaborator comment:

Is this check because the monolithic path doesn't implement shared_experts? If so, do we have an assert for that?

@yzong-rh (Author) replied:

Yeah, monolithic path does not support shared_experts. We do an assert within FusedMoEKernel when we use the monolithic implementation.

Technically, we assert that inplace is False as well, but from Rob's comment here, it seems there are plans to support inplace so I didn't add a check there.

@bnellnm left a comment:

Looks good to me. Just had one minor question.

bnellnm added a commit to neuralmagic/vllm that referenced this pull request Mar 10, 2026
…land

Signed-off-by: Bill Nell <bnell@redhat.com>
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
if (
Collaborator comment:

I fixed this issue here: #36061, so we can remove this once #36061 lands

Collaborator follow-up:

actually, you can remove this now, its not relevant to this PR

shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_all2all_kernels
if (
Collaborator comment:

ditto, nice catch on this

Collaborator follow-up:

but we can remove it since its irrelevant to this PR

shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input)
assert self.unquantized_backend != UnquantizedMoeBackend.NONE
Collaborator comment:

what causes this to return NONE for backend?

@yzong-rh (Author) replied:

NONE backend is returned for TPU and OOT.


# --8<-- [start:unquantized_fused_moe]
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
Collaborator comment:

note to self, we should make this not be a CustomOp in the future

def _supports_current_device() -> bool:
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
return (
Collaborator comment:

this is a good fix, but irrelevant to this PR. Please remove it and we can add it in a separate PR

"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
return (
Collaborator comment:

good fix, but irrelevant to this PR. Please remove it and we can add it in another Pr

p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
return (
Collaborator comment:

good fix, but irrelevant for this PR, please remove it and open up another PR

("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"),
# TODO: Add check for `trtllm_bf16_moe`?
Collaborator comment:

yes

@yzong-rh (Author) replied:

Sounds good, will do it in the other PR with the others.



class UnquantizedMoeBackend(Enum):
NONE = "NONE"
Collaborator comment:

can you update this to match the style of the other oracles?

if current_platform.is_out_of_tree():
backend = UnquantizedMoeBackend.OOT
for backend in AVAILABLE_BACKENDS:
backend = _maybe_swap_to_batched_variant(backend)
Collaborator comment:

I would prefer if this function did not exist. we now have 2 spots where we set AVAILABLE_BACKENDS. I would suggest just having BATCHED_TRITON in the lists above

Signed-off-by: Yifan Zong <yzong@redhat.com>
1. override and throw in `select_gemm_impl`
2. remove `NONE` backend and throw early on TPU/OOT platforms
3. remove _maybe_swap_to_batched_variant so that AVAILABLE_BACKENDS is set in one location
4. Use use_deepep_ll_kernels instead of use_all2all_kernels

Signed-off-by: Yifan Zong <yzong@redhat.com>
@yzong-rh left a comment:

Botched a rebase and accidentally notified everyone; sorry about the noise. No further action is required from you.

Further commits are added to #36732 .

Signed-off-by: Yifan Zong <yzong@redhat.com>
@bnellnm left a comment:

LGTM as long as @robertgshaw2-redhat 's comments are addressed.

robertgshaw2-redhat and others added 4 commits March 20, 2026 17:29
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
@yzong-rh (Author):

WIP triage of CI failures: https://gist.github.com/yzong-rh/134ce7b202a35800d90a5f41c8318969
Not sure why unquantized CUTLASS isn't working.

@robertgshaw2-redhat (Collaborator):

I broke everything!

Robert Shaw added 9 commits March 21, 2026 10:18
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>

Labels: nvidia, ready, ready-run-all-tests