[MoE Refactor] Migrate Unquantized to Full Oracle Flow #36286
yzong-rh wants to merge 25 commits into vllm-project:main
Conversation
Code Review
This pull request is a significant and well-executed refactoring of the Mixture of Experts (MoE) infrastructure, migrating the unquantized MoE method and its selection oracle to the new modular kernel (MK) flow. This greatly improves the modularity, maintainability, and extensibility of the MoE implementation. Key changes include the introduction of TrtLlmBf16Experts as a new modular kernel, a completely rewritten and more robust backend selection oracle for unquantized MoE, and strengthened platform support checks for various FlashInfer backends. The code is well-structured and successfully moves towards a more unified MoE framework. I've identified one area for improvement to ensure the correctness of the new backend's support check.
vllm/utils/flashinfer.py
Outdated
    ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
    ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"),
    # TODO: Add check for `trtllm_bf16_moe`?
The TODO comment highlights a missing check for trtllm_bf16_moe. The has_flashinfer_trtllm_fused_moe function is used by TrtLlmBf16Experts to determine if the kernel is supported. Without this check, the system might incorrectly report support for the bf16 TRT-LLM kernel, potentially leading to a runtime error if the trtllm_bf16_moe function is missing from the flashinfer library. Please add the check to ensure correctness.
-    # TODO: Add check for `trtllm_bf16_moe`?
+    ("flashinfer.fused_moe", "trtllm_bf16_moe"),
TODO is intentional. I'd like to get some eyes on this before adding it or removing the TODO.
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small subset of tests runs. You can ask your reviewers to trigger select CI tests on top of it. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
        w2_weight,
    )

    return w13_weight, w2_weight
Not sure if we need to ensure these weights are contiguous()
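As background on the question, a minimal NumPy illustration of how a strided view can be non-contiguous and why an explicit copy fixes it (PyTorch's `.contiguous()` behaves analogously on tensors):

```python
import numpy as np

w = np.arange(12, dtype=np.float32).reshape(3, 4)
w_t = w.T  # a transposed view: same data, non-contiguous strides

assert w.flags["C_CONTIGUOUS"]
assert not w_t.flags["C_CONTIGUOUS"]

# Kernels that assume a row-major layout need an explicit copy:
w_fixed = np.ascontiguousarray(w_t)
assert w_fixed.flags["C_CONTIGUOUS"]
assert (w_fixed == w_t).all()  # same values, contiguous storage
```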
    if (
        moe_config.moe_parallel_config.use_all2all_kernels and not is_monolithic
    )
Is this check because the monolithic path doesn't implement shared_experts? If so, do we have an assert for that?
Yeah, monolithic path does not support shared_experts. We do an assert within FusedMoEKernel when we use the monolithic implementation.
Technically, we assert that inplace is False as well, but from Rob's comment here, it seems there are plans to support inplace so I didn't add a check there.
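A hedged sketch of the assert described in this reply (the class skeleton is illustrative, not vLLM's actual `FusedMoEKernel`):

```python
class FusedMoEKernel:  # illustrative skeleton only
    def __init__(self, experts, shared_experts=None, is_monolithic=False):
        if is_monolithic:
            # Monolithic implementations fuse routing and GEMMs in one
            # call, so shared-expert overlap is not supported there.
            assert shared_experts is None, (
                "monolithic path does not support shared_experts"
            )
        self.experts = experts
        self.shared_experts = shared_experts
```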
…land Signed-off-by: Bill Nell <bnell@redhat.com>
    shared_experts=(
        shared_experts
        if moe_config.moe_parallel_config.use_all2all_kernels
    if (
Actually, you can remove this now; it's not relevant to this PR.
    shared_experts=(
        shared_experts
        if moe_config.moe_parallel_config.use_all2all_kernels
    if (
ditto, nice catch on this
but we can remove it since it's irrelevant to this PR
    shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input)
    assert self.unquantized_backend != UnquantizedMoeBackend.NONE
what causes this to return NONE for backend?
The NONE backend is returned for TPU and out-of-tree (OOT) platforms.
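A minimal sketch of the "throw early instead of returning a NONE sentinel" change later adopted in this PR (enum members and platform flags are illustrative, not the actual oracle):

```python
from enum import Enum

class UnquantizedMoeBackend(Enum):
    TRITON = "TRITON"
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"

def select_backend(is_tpu: bool, is_oot: bool) -> UnquantizedMoeBackend:
    # Fail fast on platforms with no modular-kernel backend, rather
    # than returning a NONE sentinel that callers must remember to check.
    if is_tpu or is_oot:
        raise NotImplementedError(
            "Unquantized MoE modular kernels are not supported on this platform"
        )
    return UnquantizedMoeBackend.TRITON
```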
# --8<-- [start:unquantized_fused_moe]
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
note to self, we should make this not be a CustomOp in the future
def _supports_current_device() -> bool:
    p = current_platform
    return p.is_cuda() and p.is_device_capability_family(100)
    return (
this is a good fix, but irrelevant to this PR. Please remove it and we can add it in a separate PR
"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
return (
good fix, but irrelevant to this PR. Please remove it and we can add it in another PR
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)
return (
good fix, but irrelevant for this PR, please remove it and open up another PR
vllm/utils/flashinfer.py
Outdated
    ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
    ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"),
    # TODO: Add check for `trtllm_bf16_moe`?
Sounds good, will do it in the other PR with the others.
class UnquantizedMoeBackend(Enum):
    NONE = "NONE"
can you update this to match the style of the other oracles?
if current_platform.is_out_of_tree():
    backend = UnquantizedMoeBackend.OOT
for backend in AVAILABLE_BACKENDS:
    backend = _maybe_swap_to_batched_variant(backend)
I would prefer if this function did not exist; we now have two spots where we set `AVAILABLE_BACKENDS`. I would suggest just having `BATCHED_TRITON` in the lists above.
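A rough sketch of the single-list style the reviewer suggests, with `BATCHED_TRITON` placed directly in the priority list so `AVAILABLE_BACKENDS` is defined in exactly one place (names are illustrative, not the actual oracle code):

```python
# One ordered priority list, including the batched variant directly,
# so there is a single source of truth for candidate backends.
AVAILABLE_BACKENDS: list[str] = [
    "FLASHINFER_TRTLLM",
    "BATCHED_TRITON",
    "TRITON",
]

def select_backend(supported: set[str]) -> str:
    """Pick the first candidate that passes the support check."""
    for backend in AVAILABLE_BACKENDS:
        if backend in supported:
            return backend
    raise RuntimeError("no supported unquantized MoE backend")
```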
Signed-off-by: Yifan Zong <yzong@redhat.com>
1. override and throw in `select_gemm_impl`
2. remove `NONE` backend and throw early on TPU/OOT platforms
3. remove `_maybe_swap_to_batched_variant` so that `AVAILABLE_BACKENDS` is set in one location
4. use `use_deepep_ll_kernels` instead of `use_all2all_kernels`

Signed-off-by: Yifan Zong <yzong@redhat.com>
Force-pushed from 457b939 to 1e36445
Signed-off-by: Yifan Zong <yzong@redhat.com>
bnellnm left a comment
LGTM as long as @robertgshaw2-redhat 's comments are addressed.
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
WIP triage of CI failures: https://gist.github.com/yzong-rh/134ce7b202a35800d90a5f41c8318969
I broke everything!
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Purpose
Migrate the unquantized MoE (BF16) code path from the legacy kernel initialization pattern to the modern modular pattern already used by FP8 and NvFP4.
The CPU backend is not migrated and remains on the old path due to interface differences (see below).
Background
There are several unquantized MoE backends (Triton, AITER, FlashInfer CUTLASS, XPU, FlashInfer TRTLLM, and CPU).

In the old path:

- Unsupported configurations returned `UNSUPPORTED_BACKEND`, and monolithic backends were implemented in `forward_monolithic_{cuda|cpu}`.
- Kernels were built with a `NoDPEP` prepare/finalize, which may be swapped for an appropriate prepare/finalize in `prepare_communication_buffer_for_model` (after weight loading and `process_weights_after_loading`).

This PR
Moved `vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py` to `vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py`. Mirrors `TrtLlmFp8Experts` and [MoE Refactor] Create MK for TRTLLM Kernels #32564.

Lifecycle: Old Path vs New Path
Non-monolithic (Triton, AITER, FlashInfer CUTLASS, XPU)
Monolithic — GPU (FlashInfer TRTLLM)
Monolithic — CPU (not migrated, unchanged)
Changes
New file:

- `experts/trtllm_bf16_moe.py` — `TrtLlmBf16Experts`, a `FusedMoEExpertsMonolithic` subclass wrapping the `flashinfer.fused_moe.trtllm_bf16_moe` call.

`oracle/unquantized.py`:

- `select_unquantized_moe_backend` now returns `(backend, experts_cls)` instead of just `backend`, mirroring FP8. CPU returns `(CPU, None)`.
- Removed `UNSUPPORTED_BACKEND`. Added `BATCHED_TRITON` enum variant and `backend_to_kernel_cls` mapping.
- Backends are filtered via `is_supported_config`; unsupported ones are logged and skipped.
- `make_unquantized_moe_kernel` now calls `maybe_make_prepare_finalize(allow_new_interface=True)` instead of hardcoding `NoDPEP`, and always returns a `FusedMoEKernel`.
- Removed `convert_to_unquantized_kernel_format`.

`unquantized_fused_moe_method.py`:

- `__init__` stores `experts_cls` from the backend selector. Removed `self.kernel`, `_is_monolithic`, and `_select_monolithic`.
- `_setup_kernel` stores the kernel in `self.moe_kernel` (not `self.kernel`), making `supports_internal_mk=True` and causing `maybe_init_modular_kernel` to no-op.
- `is_monolithic` returns `True` for CPU, and delegates to `super()` otherwise.
- `forward_native` and `forward_cuda` use `self.moe_kernel.apply()`.
- `apply_monolithic` dispatches CPU to `self.cpu_fused_moe`, all others to `self.moe_kernel.apply_monolithic()`.
- Removed `forward_monolithic_cuda`, `select_gemm_impl`, and the FlashInfer TRTLLM branch from `process_weights_after_loading`.
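The `(backend, experts_cls)` return shape described above can be sketched as follows (class and mapping names other than `TrtLlmBf16Experts` are placeholders, not the actual vLLM definitions):

```python
# Placeholder expert classes standing in for the real modular kernels.
class TrtLlmBf16Experts: ...
class TritonExperts: ...

# Sketch of a backend_to_kernel_cls-style mapping; CPU stays on the
# legacy path, so it maps to None.
BACKEND_TO_KERNEL_CLS = {
    "FLASHINFER_TRTLLM": TrtLlmBf16Experts,
    "TRITON": TritonExperts,
    "CPU": None,
}

def select_unquantized_moe_backend(backend: str):
    """Return (backend, experts_cls), mirroring the FP8 oracle."""
    return backend, BACKEND_TO_KERNEL_CLS[backend]
```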
Other cleanups:

- Removed the `rocm_aiter_moe_enabled` condition.
- Removed the `NONE` backend for TPU/OOT (mirrors [Bugfix][TPU] Return a Default fp8 MoE Backend #32908).
- `shared_experts` is now passed to `FusedMoEExpertsMonolithic`.

The CPU backend (`CPUFusedMOE`/`SGLFusedMOE`) stays on the old monolithic path because it has three interface differences that make a clean migration non-trivial:

- It depends on `FusedMoEConfig` fields (`renormalize`, `scoring_func`, `custom_routing_function`).
- It uses an `apply(w1, w2, ...)` interface.
Integration tests:
Updated unit tests:
Other unit tests:
Test Result
All unit tests pass on B200 machine
cc @robertgshaw2-redhat @bnellnm
Essential Elements of an Effective PR Description Checklist
Update `supported_models.md` and `examples` for a new model.